{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.Protostar.AlgebraicMap where

import           Data.ByteString                                       (ByteString)
import           Data.Functor.Rep                                      (Representable (..))
import           Data.List                                             (foldl')
import           Data.Map.Strict                                       (Map, keys)
import qualified Data.Map.Strict                                       as M
import           Prelude                                               (fmap, zip, ($), (.), (<$>))
import qualified Prelude                                               as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import qualified ZkFold.Base.Algebra.Polynomials.Multivariate          as PM
import           ZkFold.Base.Algebra.Polynomials.Multivariate
import qualified ZkFold.Base.Data.Vector                               as V
import           ZkFold.Base.Data.Vector                               (Vector)
import           ZkFold.Base.Protocol.Protostar.ArithmetizableFunction (ArithmetizableFunction (..))
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Compiler
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
import           ZkFold.Symbolic.Data.Eq

-- | Algebraic map of @a@.
--
-- An instance turns a value of type @a@ into a list of ring elements: the
-- values of a system of equations associated with @a@, evaluated on the
-- supplied inputs. All inputs are polymorphic in the ring @f@.
-- The main application is to define the verifier's algebraic map in the NARK protocol.
--
-- The index @d@ does not occur in the type of 'algebraicMap', so callers have
-- to pick it explicitly (the module enables @AllowAmbiguousTypes@ for this).
--
class Ring f => AlgebraicMap f i (d :: Natural) a where
    -- | The algebraic map Vsps computed by the NARK verifier.
    algebraicMap
        :: a
        -> i f            -- ^ Public input
        -> Vector k [f]   -- ^ NARK proof witness (the list of prover messages)
        -> Vector (k-1) f -- ^ Verifier random challenges
        -> f              -- ^ Slack variable for padding
        -> [f]            -- ^ Values of the algebraic map

instance
  ( Ring f
  , Representable i
  , KnownNat (d + 1)
  , Arithmetic a
  , Scale a f
  ) => AlgebraicMap f i d (ArithmetizableFunction a i p) where
    -- We can use the polynomial system from the circuit as a base for Vsps.
    --
    -- Steps:
    --   1. take the constraint polynomials of the function's arithmetic circuit;
    --   2. split every polynomial into its degree-j parts for the @d + 1@
    --      components produced by 'degreeDecomposition';
    --   3. evaluate each part on the public input and the prover's first message;
    --   4. recombine the parts with 'padDecomposition', using the slack value.
    --
    -- The verifier challenges (third argument) are not used by this instance.
    --
    algebraicMap ArithmetizableFunction { afCircuit = circuit } pubInput messages _ pad =
        padDecomposition pad evaluated
        where
            -- Constraint polynomials of the circuit (values of the system map,
            -- in ascending key order).
            constraints :: [PM.Poly a (SysVar i) Natural]
            constraints = M.elems (acSystem circuit)

            -- Assignment of the circuit's witness variables: the keys of the
            -- circuit's witness map are paired positionally with the elements
            -- of the first prover message. 'zip' stops at the shorter list, so
            -- surplus message elements are ignored and variables left without
            -- a value fall back to 'zero' in 'assign' below.
            -- NOTE(review): 'V.head' is applied to a vector of arbitrary
            -- length @k@; presumably callers guarantee at least one message —
            -- confirm.
            witness :: Map ByteString f
            witness = M.fromList (zip (keys (acWitness circuit)) (V.head messages))

            -- Value of a system variable: input variables are read from the
            -- public input, new variables from the witness assignment.
            assign :: SysVar i -> f
            assign (InVar inV)   = index pubInput inV
            assign (NewVar newV) = M.findWithDefault zero newV witness

            -- The constraint system split by monomial degree.
            byDegree :: Vector (d+1) [PM.Poly a (SysVar i) Natural]
            byDegree = degreeDecomposition @_ @d constraints

            -- Every degree component evaluated under 'assign'.
            evaluated :: Vector (d+1) [f]
            evaluated = fmap (PM.evalPolynomial PM.evalMonomial assign) <$> byDegree


-- | Recombines the components of a decomposed map into a single list.
--
-- With @d = n - 1@, every element of the component at index @j@ is multiplied
-- by @pad ^ (d - j)@, and the scaled components are then added together
-- position by position.
--
-- Because the sum is a 'P.zipWith' fold seeded with @'P.repeat' 'zero'@:
--
-- * the result is as long as the shortest component;
-- * for an empty vector (@n = 0@) the result is the infinite list of zeros.
--
-- NOTE(review): this assumes 'V.mapWithIx' passes indices below @n@, so that
-- the subtraction @d -! j@ never underflows — confirm against its definition.
padDecomposition :: forall f n .
    ( MultiplicativeMonoid f
    , AdditiveMonoid f
    , KnownNat n
    ) => f -> V.Vector n [f] -> [f]
padDecomposition pad = foldl' (P.zipWith (+)) (P.repeat zero) . V.mapWithIx scaleComponent
    where
        -- Index of the last component; only evaluated when the vector is non-empty.
        d :: Natural
        d = value @n -! 1

        -- Multiplies every element of the j-th component by @pad ^ (d - j)@.
        scaleComponent :: Natural -> [f] -> [f]
        scaleComponent j component = ((pad ^ (d -! j)) *) <$> component

-- | Decomposes an algebraic map into homogeneous degree-j maps, one per index
-- of the resulting vector of length @d + 1@.
--
-- The component for degree @j@ contains every input polynomial restricted to
-- its monomials of total degree exactly @j@ (see 'deg'); the order and number
-- of polynomials is the same in every component. Monomials whose degree is not
-- among the produced indices appear in no component.
--
degreeDecomposition :: forall f d v . KnownNat (d + 1) => [Poly f v Natural] -> V.Vector (d + 1) [Poly f v Natural]
degreeDecomposition lmap = tabulate (degree_j . toConstant)
    where
        -- All polynomials of the map, each restricted to degree @j@.
        degree_j :: Natural -> [Poly f v Natural]
        degree_j j = P.fmap (leaveDeg j) lmap

        -- Keeps only the monomials of total degree @j@ in a polynomial.
        leaveDeg :: Natural -> PM.Poly f v Natural -> PM.Poly f v Natural
        leaveDeg j (PM.P monomials) = PM.P $ P.filter (\(_, m) -> deg m == j) monomials

-- | Total degree of a monomial: the sum of the exponents stored in its
-- variable-to-exponent map.
deg :: PM.Mono v Natural -> Natural
deg (PM.M m) = sum m