module ZkFold.Base.Protocol.ARK.Protostar.Gate where

import           Data.Zip                                        (zipWith)
import           Numeric.Natural                                 (Natural)
import           Prelude                                         hiding (Num (..), zipWith, (!!), (^))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field                 (Zp)
import           ZkFold.Base.Algebra.Basic.Number                (KnownNat)
import           ZkFold.Base.Algebra.Polynomials.Multivariate    (Polynomial', evalPolynomial, evalVectorM, subs, var)
import           ZkFold.Base.Data.Matrix                         (Matrix (..), outer, sum1, transpose)
import           ZkFold.Base.Data.Vector                         (Vector)
import           ZkFold.Base.Protocol.ARK.Protostar.Internal     (PolynomialProtostar)
import           ZkFold.Base.Protocol.ARK.Protostar.SpecialSound (SpecialSoundProtocol (..), SpecialSoundTranscript)
import           ZkFold.Symbolic.Compiler.Arithmetizable         (Arithmetic)

-- | Type-level tag selecting the Protostar gate 'SpecialSoundProtocol'
-- instance. It has no constructors and is never inhabited.
--
-- Judging by that instance, the parameters appear to be, in order: the
-- number of gate polynomials, the number of gate positions (the protocol
-- 'Dimension'), the number of variables per gate polynomial, and the
-- polynomial 'Degree'.
data ProtostarGate (gates :: Natural) (positions :: Natural) (vars :: Natural) (deg :: Natural)

-- | The Protostar gate relation as a one-round special-sound protocol.
--
-- The prover sends its witness unchanged. For every gate position @j@ the
-- verifier computes @sum_i s_{i,j} * G_i(w_j)@ over the @m@ gate polynomials
-- @G_i@ and the selector matrix @s@ taken from the 'Input'. It accepts only
-- if all @n@ of these sums are zero.
instance (Arithmetic f, KnownNat m, KnownNat n, KnownNat c) => SpecialSoundProtocol f (ProtostarGate m n c d) where
    type Witness f (ProtostarGate m n c d)       = Vector n (Vector c f)
    -- ^ [w_j]_{j=1}^n from the paper: for each of the n gate positions, the c values
    -- substituted for the variables of the gate polynomials
    type Input f (ProtostarGate m n c d)         = (Matrix m n f, Vector m (PolynomialProtostar f c d))
    -- ^ [s_{i, j}] and [G_i]_{i=1}^m in the paper
    type ProverMessage t (ProtostarGate m n c d)  = Vector n (Vector c t)
    -- ^ same shape as Witness
    type VerifierMessage t (ProtostarGate m n c d) = ()

    type Dimension (ProtostarGate m n c d)        = n
    type Degree (ProtostarGate m n c d)           = d

    -- A single round: the prover sends the witness and the verifier replies with ().
    rounds :: ProtostarGate m n c d -> Natural
    rounds _ = 1

    -- The prover message is the witness itself; neither the input nor the
    -- transcript so far is consulted.
    prover :: ProtostarGate m n c d
          -> Witness f (ProtostarGate m n c d)
          -> Input f (ProtostarGate m n c d)
          -> SpecialSoundTranscript f (ProtostarGate m n c d)
          -> ProverMessage f (ProtostarGate m n c d)
    prover _ w _ _ = w

    -- Symbolic counterpart of 'verifier'. The transcript carries 'Natural's, and
    -- each one is turned into a polynomial with 'var' (presumably a variable
    -- index; TODO confirm against the caller). Instead of comparing with zero,
    -- it returns, for every gate position j, the polynomial sum_i s_{i,j} * G_i(w_j).
    verifier' :: ProtostarGate m n c d
              -> Input f (ProtostarGate m n c d)
              -> SpecialSoundTranscript Natural (ProtostarGate m n c d)
              -> Vector (Dimension (ProtostarGate m n c d)) (Polynomial' f)
    verifier' _ (s, g) [(w, _)] =
      let w' = fmap ((var .) . subs) w :: Vector n (Zp c -> Polynomial' f)
          -- 'outer' evaluates every G_i at every assignment w_j, giving an n x m
          -- matrix; transpose it to line up with the m x n selector matrix.
          z  = transpose $ outer (evalPolynomial evalVectorM) w' g
      -- 'sum1' adds up the m rows, leaving one entry per gate position.
      in sum1 $ zipWith scale s z
    -- The result type has no failure case, so a transcript that is not
    -- exactly one round is still reported as an error here.
    verifier' _ _ _ = error "Invalid transcript"

    -- Accepts iff sum_i s_{i,j} * G_i(w_j) == 0 for every gate position j.
    verifier :: ProtostarGate m n c d
             -> Input f (ProtostarGate m n c d)
             -> SpecialSoundTranscript f (ProtostarGate m n c d)
             -> Bool
    verifier _ (s, g) [(w, _)] =
      let w' = fmap subs w :: Vector n (Zp c -> f)
          -- z_{i,j} = G_i evaluated at the assignment w_j (m x n after transpose).
          z  = transpose $ outer (evalPolynomial evalVectorM) w' g
      in all (== zero) $ sum1 $ zipWith (*) s z
    -- 'rounds' is 1, so any transcript that is not exactly one message pair is
    -- malformed. Reject it instead of crashing the verifier on bad input.
    verifier _ _ _ = False