{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.Plonk (
    Plonk (..)
) where

import           Data.Binary                                         (Binary)
import           Data.Functor.Rep                                    (Rep)
import           Data.Kind                                           (Type)
import           Data.Word                                           (Word8)
import           Prelude                                             hiding (Num (..), div, drop, length, replicate,
                                                                      sum, take, (!!), (/), (^))
import qualified Prelude                                             as P hiding (length)
import           Test.QuickCheck                                     (Arbitrary (..))

import           ZkFold.Base.Algebra.Basic.Class                     (AdditiveGroup)
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.EllipticCurve.Class             (EllipticCurve (..), Pairing, PointCompressed)
import           ZkFold.Base.Data.Vector                             (Vector (..))
import           ZkFold.Base.Protocol.NonInteractiveProof
import           ZkFold.Base.Protocol.Plonk.Prover                   (plonkProve)
import           ZkFold.Base.Protocol.Plonk.Verifier                 (plonkVerify)
import           ZkFold.Base.Protocol.Plonkup.Input
import           ZkFold.Base.Protocol.Plonkup.Internal
import           ZkFold.Base.Protocol.Plonkup.Proof
import           ZkFold.Base.Protocol.Plonkup.Prover
import           ZkFold.Base.Protocol.Plonkup.Verifier
import           ZkFold.Base.Protocol.Plonkup.Witness
import           ZkFold.Symbolic.Compiler                            (desugarRanges)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

{-| Parameters of the PLONK proof system.

Based on the paper https://eprint.iacr.org/2019/953.pdf

This record carries the same five fields as 'Plonkup' and is converted to and
from it by @fromPlonkup@ / @toPlonkup@ in this module.

Type parameters:

  * @p@ — the @p@ parameter of the underlying 'ArithmeticCircuit'.
  * @i@, @l@ — lengths of the circuit's input and output vectors
    (the circuit is @ArithmeticCircuit _ p (Vector i) (Vector l)@).
  * @n@ — a type-level size required as 'KnownNat' by the prover and
    verifier. NOTE(review): presumably the number of constraint rows /
    evaluation-domain size — confirm against 'Plonkup'.
  * @curve1@, @curve2@ — the curve pair; the 'NonInteractiveProof' instance
    requires @Pairing curve1 curve2@.
  * @transcript@ — the transcript type (it is the instance's 'Transcript').
-}
data Plonk p (i :: Natural) (n :: Natural) (l :: Natural) curve1 curve2 transcript = Plonk
    { omega :: ScalarField curve1
      -- ^ NOTE(review): presumably the generator (root of unity) of the
      -- evaluation domain — confirm.
    , k1    :: ScalarField curve1
      -- ^ NOTE(review): presumably the first coset shift of the permutation
      -- argument — confirm.
    , k2    :: ScalarField curve1
      -- ^ NOTE(review): presumably the second coset shift of the permutation
      -- argument — confirm.
    , ac    :: ArithmeticCircuit (ScalarField curve1) p (Vector i) (Vector l)
      -- ^ The arithmetic circuit over the scalar field of @curve1@, with
      -- @Vector i@ inputs and @Vector l@ outputs.
    , x     :: ScalarField curve1
      -- ^ NOTE(review): presumably the trusted-setup secret from which the
      -- SRS is derived — confirm.
    }

-- | Build a 'Plonk' description from a 'Plonkup' one.
--
-- @omega@, @k1@, @k2@ and @x@ are copied unchanged. The circuit @ac@ is
-- first passed through 'desugarRanges'.
-- NOTE(review): presumably this turns range constraints (which Plonkup can
-- check with lookups) into ordinary arithmetic constraints that plain PLONK
-- can prove — confirm in "ZkFold.Symbolic.Compiler".
fromPlonkup ::
    ( KnownNat i
    , Arithmetic (ScalarField c1)
    , Binary (ScalarField c1)
    , Binary (Rep p)
    ) => Plonkup p i n l c1 c2 ts -> Plonk p i n l c1 c2 ts
fromPlonkup Plonkup {..} = Plonk { ac = desugarRanges ac, .. }

-- | Reinterpret a 'Plonk' description as a 'Plonkup' one.
--
-- The fields are copied across by name with no transformation. This is not
-- an inverse of 'fromPlonkup': if the value came from 'fromPlonkup', its
-- circuit has already been through 'desugarRanges', and nothing undoes that
-- here.
toPlonkup :: Plonk p i n l c1 c2 ts -> Plonkup p i n l c1 c2 ts
toPlonkup Plonk {..} = Plonkup {..}

-- | Debug rendering of a 'Plonk' value.
--
-- The output is the label @"Plonk: "@ followed by the 'show' of, in order:
-- @omega@, @k1@, @k2@, the circuit's output variables ('acOutput'), the
-- circuit itself, and @x@. The items are separated by single spaces.
instance (Show (ScalarField c1), Arithmetic (ScalarField c1), KnownNat l, KnownNat i) => Show (Plonk p i n l c1 c2 t) where
    show Plonk {..} =
        "Plonk: " ++ unwords
            [ show omega
            , show k1
            , show k2
            , show (acOutput ac)
            , show ac
            , show x
            ]

-- | Random 'Plonk' values: generate a random 'Plonkup' and convert it with
-- 'fromPlonkup', so the generated circuit has been passed through
-- 'desugarRanges'.
instance ( KnownNat i, Arithmetic (ScalarField c1)
         , Binary (ScalarField c1), Binary (Rep p)
         , Arbitrary (Plonkup p i n l c1 c2 t))
        => Arbitrary (Plonk p i n l c1 c2 t) where
    arbitrary = fromPlonkup <$> arbitrary

-- | 'Plonk' as a non-interactive proof system, built on the Plonkup
-- machinery.
--
-- * Every associated type is the corresponding Plonkup type.
-- * Both setups convert with 'toPlonkup' and defer to the
--   @'Plonkup' … ts@ instance (whose associated types the equality
--   constraints pin to the same Plonkup types).
-- * Proving and verifying call 'plonkProve' and 'plonkVerify' directly.
instance forall p i n l c1 c2 (ts :: Type) core .
        ( NonInteractiveProof (Plonkup p i n l c1 c2 ts) core
        , SetupProve (Plonkup p i n l c1 c2 ts) ~ PlonkupProverSetup p i n l c1 c2
        , SetupVerify (Plonkup p i n l c1 c2 ts) ~ PlonkupVerifierSetup p i n l c1 c2
        , Witness (Plonkup p i n l c1 c2 ts) ~ (PlonkupWitnessInput p i c1, PlonkupProverSecret c1)
        , Input (Plonkup p i n l c1 c2 ts) ~ PlonkupInput l c1
        , Proof (Plonkup p i n l c1 c2 ts) ~ PlonkupProof c1
        , KnownNat n
        , Ord (BaseField c1)
        , AdditiveGroup (BaseField c1)
        , Pairing c1 c2
        , Arithmetic (ScalarField c1)
        , ToTranscript ts Word8
        , ToTranscript ts (ScalarField c1)
        , ToTranscript ts (PointCompressed c1)
        , FromTranscript ts (ScalarField c1)
        , CoreFunction c1 core
        ) => NonInteractiveProof (Plonk p i n l c1 c2 ts) core where
    type Transcript (Plonk p i n l c1 c2 ts)  = ts
    type SetupProve (Plonk p i n l c1 c2 ts)  = PlonkupProverSetup p i n l c1 c2
    type SetupVerify (Plonk p i n l c1 c2 ts) = PlonkupVerifierSetup p i n l c1 c2
    type Witness (Plonk p i n l c1 c2 ts)     = (PlonkupWitnessInput p i c1, PlonkupProverSecret c1)
    type Input (Plonk p i n l c1 c2 ts)       = PlonkupInput l c1
    type Proof (Plonk p i n l c1 c2 ts)       = PlonkupProof c1

    -- Prover setup: reuse the Plonkup instance on the converted record.
    setupProve :: Plonk p i n l c1 c2 ts -> SetupProve (Plonk p i n l c1 c2 ts)
    setupProve = setupProve @(Plonkup p i n l c1 c2 ts) @core . toPlonkup

    -- Verifier setup: reuse the Plonkup instance on the converted record.
    setupVerify :: Plonk p i n l c1 c2 ts -> SetupVerify (Plonk p i n l c1 c2 ts)
    setupVerify = setupVerify @(Plonkup p i n l c1 c2 ts) @core . toPlonkup

    -- 'plonkProve' returns a triple. Its third component, a
    -- 'PlonkupProverTestInfo', is not part of the public proof and is
    -- dropped here (NOTE(review): presumably only used by tests).
    prove
        :: SetupProve (Plonk p i n l c1 c2 ts)
        -> Witness (Plonk p i n l c1 c2 ts)
        -> (Input (Plonk p i n l c1 c2 ts), Proof (Plonk p i n l c1 c2 ts))
    prove setup witness =
        let (input, proof, _testInfo) = plonkProve @p @i @n @l @c1 @c2 @ts @core setup witness
        in (input, proof)

    verify
        :: SetupVerify (Plonk p i n l c1 c2 ts)
        -> Input (Plonk p i n l c1 c2 ts)
        -> Proof (Plonk p i n l c1 c2 ts)
        -> Bool
    verify = plonkVerify @p @i @n @l @c1 @c2 @ts