{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.Plonkup.Internal where

import           Data.Constraint                                     (withDict)
import           Data.Constraint.Nat                                 (plusNat, timesNat)
import           Data.Functor.Classes                                (Show1)
import           Data.Functor.Rep                                    (Rep)
import           Prelude                                             hiding (Num (..), drop, length, sum, take, (!!),
                                                                      (/), (^))
import           Test.QuickCheck                                     (Arbitrary (..))

import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.EllipticCurve.Class             (EllipticCurve (..))
import           ZkFold.Base.Algebra.Polynomials.Univariate          (PolyVec)
import           ZkFold.Base.Protocol.Plonkup.Utils
import           ZkFold.Symbolic.Compiler                            ()
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

{-
    NOTE: we need to parametrize the type of transcripts because we use BuiltinByteString on-chain and ByteString off-chain.
    Additionally, we don't want this library to depend on Cardano libraries.
-}

-- | Parameters of a Plonkup protocol instance.
--
-- Only @p@, @i@, @l@ and @curve1@ occur in the fields; @n@, @curve2@ and
-- @transcript@ are carried purely at the type level.
data Plonkup p i (n :: Natural) l curve1 curve2 transcript = Plonkup {
        -- | First scalar-field parameter. The 'Arbitrary' instance in this
        -- module takes it from the first component of @getParams@.
        -- NOTE(review): presumably a root of unity for the evaluation domain
        -- — confirm against @getParams@.
        omega :: ScalarField curve1,
        -- | Second scalar-field parameter (second component of @getParams@
        -- in the 'Arbitrary' instance).
        k1    :: ScalarField curve1,
        -- | Third scalar-field parameter (third component of @getParams@
        -- in the 'Arbitrary' instance).
        k2    :: ScalarField curve1,
        -- | The arithmetic circuit over the scalar field of @curve1@.
        ac    :: ArithmeticCircuit (ScalarField curve1) p i l,
        -- | An additional scalar-field element, chosen independently of the
        -- other fields by the 'Arbitrary' instance.
        -- NOTE(review): presumably the setup secret — confirm with callers.
        x     :: ScalarField curve1
    }

-- | Type-level natural @3 * n@.
-- NOTE(review): presumably the length of the permutation spanning three
-- columns of @n@ rows — confirm against the users of this synonym.
type PlonkupPermutationSize n = 3 * n

-- | Type-level length @4 * n + 6@ used for the extended polynomial vectors.
--
-- The maximum degree of the polynomials we need in the protocol is @4 * n + 5@;
-- the length is one larger than that degree (presumably one slot per
-- coefficient, including the constant term).
type PlonkupPolyExtendedLength n = 4 * n + 6

-- | Run a continuation that needs @KnownNat (4 * n + 6)@, given only
-- @KnownNat n@.
--
-- The evidence is built in two steps: 'timesNat' derives @KnownNat (4 * n)@
-- from @KnownNat 4@ and @KnownNat n@, then 'plusNat' derives
-- @KnownNat (4 * n + 6)@ from that and @KnownNat 6@.
with4n6 :: forall n {r}. KnownNat n => (KnownNat (4 * n + 6) => r) -> r
with4n6 f = withDict (timesNat @4 @n) (withDict (plusNat @(4 * n) @6) f)

-- | A 'PolyVec' over the scalar field of curve @c@ whose length is
-- 'PlonkupPolyExtendedLength' @n@ (i.e. @4 * n + 6@).
type PlonkupPolyExtended n c = PolyVec (ScalarField c) (PlonkupPolyExtendedLength n)

-- | Renders a 'Plonkup' as the prefix @"Plonkup: "@ followed by, separated by
-- single spaces: 'omega', 'k1', 'k2', the circuit output (@acOutput ac@), the
-- circuit itself, and 'x'.
instance (Show (ScalarField c1), Show (Rep i), Show1 l, Ord (Rep i)) => Show (Plonkup p i n l c1 c2 t) where
    show Plonkup {..} =
        "Plonkup: " ++ show omega
            ++ " " ++ show k1
            ++ " " ++ show k2
            ++ " " ++ show (acOutput ac)
            ++ " " ++ show ac
            ++ " " ++ show x

-- | Random 'Plonkup' values for property tests.
--
-- The circuit 'ac' and the scalar 'x' are generated with their own
-- 'Arbitrary' instances. 'omega', 'k1' and 'k2' are not random: they are
-- computed from the type-level @n@ via @getParams (value \@n)@.
instance
  ( KnownNat n, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1)
  , Arbitrary (ArithmeticCircuit (ScalarField c1) p i l)
  ) => Arbitrary (Plonkup p i n l c1 c2 t) where
    arbitrary = do
        ac <- arbitrary
        -- Parameters determined by the size @n@ alone.
        let (omega, k1, k2) = getParams (value @n)
        -- The final field 'x' is drawn from the scalar field's generator.
        Plonkup omega k1 k2 ac <$> arbitrary