{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators       #-}

module ZkFold.Base.Protocol.IVC.SpecialSound where

import           Data.Functor.Rep                      (Representable (..))
import           Data.Map.Strict                       (elems)
import           GHC.Generics                          ((:*:) (..))
import           Prelude                               (undefined, ($))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Vector               (Vector)
import qualified ZkFold.Base.Protocol.IVC.AlgebraicMap as AM
import           ZkFold.Base.Protocol.IVC.Predicate    (Predicate (..))
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Compiler

{-- | Section 3.1

The protocol $\Pi_{sps}$ has three essential parameters $k, d, l \in \mathbb{N}$: $\Pi_{sps}$
is a $(2k - 1)$-move protocol with verifier degree $d$ and output length $l$ (i.e. the
verifier checks $l$ algebraic equations of degree $d$). In each round $i$ ($1 \le i \le k$),
the prover $P_{sps}(pi, w, [m_j, r_j]_{j=1}^{i-1})$ generates the next message $m_i$ from
the public input $pi$, the witness $w$ and the current transcript $[m_j, r_j]_{j=1}^{i-1}$,
and sends $m_i$ to the verifier; the verifier replies with a random challenge
$r_i \in \mathbb{F}$. After the final message $m_k$, the verifier computes the algebraic
map $V_{sps}$ and checks that its output is the zero vector of length $l$.

--}

-- | A @(2k - 1)@-move special-sound protocol, represented by its three algorithms.
--
-- The type parameters are:
--
-- * @k@ is the number of prover messages.
-- * @i@ and @p@ are the containers of the public input and of the witness; both
--   are filled with values of type @f@.
-- * @m@ is the type of a single prover message.
-- * @o@ is the type of the verifier output.
-- * @f@ is the type of the verifier's random challenges, of which there are
--   @k - 1@.
data SpecialSoundProtocol k i p m o f = SpecialSoundProtocol
  { -- | Derives the public input from the previous public input and the witness.
    input ::
         i f                            -- ^ previous public input
      -> p f                            -- ^ witness
      -> i f                            -- ^ public input

    -- | Produces the prover message for one round.
  , prover ::
         i f                            -- ^ previous public input
      -> p f                            -- ^ witness
      -> f                              -- ^ current random challenge
      -> Natural                        -- ^ round number (starting from 1)
      -> m                              -- ^ prover message

    -- | Runs the verifier on the whole transcript: all @k@ prover messages and
    -- the @k - 1@ random challenges sent in reply to them.
  , verifier ::
         i f                            -- ^ public input
      -> Vector k m                     -- ^ prover messages
      -> Vector (k-1) f                 -- ^ random challenges
      -> o                              -- ^ verifier output
  }

-- | The single-round (@k = 1@) special-sound protocol of a predicate, evaluated
-- natively over the field @a@.
--
-- The public-input function is the predicate's own 'predicateEval'.
--
-- The prover runs the witness generator of 'predicateCircuit' on three values:
-- the previous public input, the witness, and the new public input computed from
-- them. It returns the values of the generated witness map in ascending key
-- order. With only one round, it ignores the challenge and the round number.
--
-- The verifier is 'AM.algebraicMap' instantiated at @d@ and called with 'one' as
-- its last argument. NOTE(review): that argument is presumably the
-- homogenisation factor; confirm in "ZkFold.Base.Protocol.IVC.AlgebraicMap".
--
-- @d@ occurs only in the constraints, so callers must fix it with a type
-- application.
specialSoundProtocol :: forall d a i p .
    ( KnownNat (d+1)
    , Arithmetic a
    , Representable i
    , Representable p
    ) => Predicate a i p -> SpecialSoundProtocol 1 i p [a] [a] a
specialSoundProtocol phi =
    -- Positional fields, in declaration order: input, prover, verifier.
    SpecialSoundProtocol evalPredicate spsProver spsVerifier
  where
    evalPredicate :: i a -> p a -> i a
    evalPredicate = predicateEval phi

    -- The locals are not named 'prover' / 'verifier', because those names would
    -- shadow the record field selectors of 'SpecialSoundProtocol'.
    spsProver :: i a -> p a -> a -> Natural -> [a]
    spsProver prevInput witness _challenge _round =
        elems $ witnessGenerator (predicateCircuit phi)
                                 (prevInput :*: witness)
                                 (evalPredicate prevInput witness)

    spsVerifier :: i a -> Vector 1 [a] -> Vector 0 a -> [a]
    spsVerifier pubInput messages challenges =
        AM.algebraicMap @d phi pubInput messages challenges one

-- | Verifier-only counterpart of 'specialSoundProtocol'. It works over any ring
-- @f@ that the predicate's field @a@ scales into (@Scale a f@).
--
-- Only the verifier is implemented, with the same form as the one in
-- 'specialSoundProtocol'. The public-input and prover fields are 'undefined' and
-- must never be forced.
--
-- NOTE(review): consider replacing them with descriptive 'error' calls. That
-- needs @error@ added to the Prelude import list.
specialSoundProtocol' :: forall d a i p f .
    ( KnownNat (d+1)
    , Representable i
    , Ring f
    , Scale a f
    ) => Predicate a i p -> SpecialSoundProtocol 1 i p [f] [f] f
specialSoundProtocol' phi =
    -- Positional fields, in declaration order: input, prover, verifier.
    -- The first two are not implemented for this variant.
    SpecialSoundProtocol undefined undefined spsVerifier
  where
    -- Not named 'verifier', so that it does not shadow the record field
    -- selector of 'SpecialSoundProtocol'.
    spsVerifier :: i f -> Vector 1 [f] -> Vector 0 f -> [f]
    spsVerifier pubInput messages challenges =
        AM.algebraicMap @d phi pubInput messages challenges one