{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators       #-}

module ZkFold.Base.Protocol.IVC.AlgebraicMap (algebraicMap) where

import           Data.ByteString                                     (ByteString)
import           Data.Functor.Rep                                    (Representable (..))
import           Data.List                                           (foldl')
import           Data.Map.Strict                                     (Map, keys)
import qualified Data.Map.Strict                                     as M
import           Prelude                                             (fmap, zip, ($), (.), (<$>))
import qualified Prelude                                             as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import qualified ZkFold.Base.Algebra.Polynomials.Multivariate        as PM
import           ZkFold.Base.Algebra.Polynomials.Multivariate
import qualified ZkFold.Base.Data.Vector                             as V
import           ZkFold.Base.Data.Vector                             (Vector)
import           ZkFold.Base.Protocol.IVC.Predicate                  (Predicate (..))
import           ZkFold.Symbolic.Compiler
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
import           ZkFold.Symbolic.Data.Eq

-- | Algebraic map of @a@.
-- It calculates a system of equations defining @a@ in some way.
-- The inputs are polymorphic in a ring element @f@.
-- The main application is to define the verifier's algebraic map in the NARK protocol.
--
-- Arguments, in order:
--
-- * the predicate, whose circuit ('acSystem') supplies the constraint polynomials;
-- * values for the circuit's input variables ('InVar'), looked up with 'index';
-- * prover messages: only the first one ('V.head') is read, and its entries are
--   assigned to the circuit's witness variables ('NewVar') in ascending key order
--   of 'acWitness'; witness variables without a value evaluate to 'zero';
-- * an argument that is not used here;
-- * the padding scalar handed to 'padDecomposition'.
--
-- Every constraint is split by 'degreeDecomposition' into its degree-@j@ parts
-- (@j = 0..d@), each part is evaluated under the assignment above, and the
-- evaluated parts are recombined by 'padDecomposition'.
--
-- NOTE(review): 'V.head' presumably fails when @k = 0@ — confirm callers never do that.
algebraicMap :: forall d k a i p f .
    ( KnownNat (d+1)
    , Representable i
    , Ring f
    , Scale a f
    )
    => Predicate a i p
    -> i f
    -> Vector k [f]
    -> Vector (k-1) f
    -> f
    -> [f]
algebraicMap Predicate {..} pi pm _ pad = padDecomposition pad evaluatedParts
    where
        -- All constraint polynomials of the predicate circuit.
        constraints :: [PM.Poly a (SysVar i) Natural]
        constraints = M.elems (acSystem predicateCircuit)

        -- Names of the circuit's witness variables, in ascending key order.
        witnessNames :: [ByteString]
        witnessNames = keys (acWitness predicateCircuit)

        -- Witness names paired positionally with the first prover message.
        -- 'keys' is strictly ascending and 'zip' keeps a prefix of it, so the
        -- precondition of 'M.fromDistinctAscList' holds.
        witnessValues :: Map ByteString f
        witnessValues = M.fromDistinctAscList $ zip witnessNames (V.head pm)

        -- Value of a system variable: circuit inputs come from @pi@,
        -- witness variables from 'witnessValues' ('zero' when absent).
        assign :: SysVar i -> f
        assign var = case var of
            InVar  inV  -> index pi inV
            NewVar newV -> M.findWithDefault zero newV witnessValues

        -- Evaluates one constraint polynomial under 'assign'.
        evalPoly :: PM.Poly a (SysVar i) Natural -> f
        evalPoly = PM.evalPolynomial PM.evalMonomial assign

        -- The degree-j parts of all constraints (j = 0..d), evaluated.
        evaluatedParts :: Vector (d+1) [f]
        evaluatedParts = fmap evalPoly <$> degreeDecomposition @d constraints

-- | Recombines the evaluated degree-@j@ parts (@j = 0..d@) into a single list.
--
-- The component at index @j@ is multiplied element-wise by @pad ^ (d - j)@.
-- All scaled components are then added element-wise with a left fold that
-- starts from an all-'zero' accumulator. Because 'P.zipWith' is used, the
-- result is as long as the shortest component.
padDecomposition :: forall d f .
    ( MultiplicativeMonoid f
    , AdditiveMonoid f
    , KnownNat (d+1)
    ) => f -> V.Vector (d+1) [f] -> [f]
padDecomposition pad parts = foldl' addPointwise (P.repeat zero) scaled
    where
        -- The top degree @d@, recovered from the vector length @d+1@.
        maxDeg :: Natural
        maxDeg = value @(d+1) -! 1

        -- Element-wise sum of the accumulator with the next component.
        addPointwise :: [f] -> [f] -> [f]
        addPointwise acc xs = P.zipWith (+) acc xs

        scaled :: V.Vector (d+1) [f]
        scaled = V.mapWithIx scaleBy parts

        -- One factor pad^(d-j) is computed and shared across the j-th component.
        scaleBy :: Natural -> [f] -> [f]
        scaleBy j component =
            let factor = pad ^ (maxDeg -! j)
            in fmap (factor *) component

-- | Decomposes an algebraic map into homogeneous degree-@j@ maps for @j@ from 0 to @d@.
--
-- Entry @j@ of the result contains every input polynomial, in the same order,
-- restricted to the monomials whose total degree ('deg') equals @j@.
-- Monomials of total degree above @d@ appear in no entry.
degreeDecomposition :: forall d f v . KnownNat (d+1) => [Poly f v Natural] -> V.Vector (d+1) [Poly f v Natural]
degreeDecomposition lmap = tabulate (homogeneousPart . toConstant)
    where
        -- Every input polynomial, keeping only its degree-j monomials.
        homogeneousPart :: Natural -> [Poly f v Natural]
        homogeneousPart j = P.map (onlyDegree j) lmap

        onlyDegree :: Natural -> PM.Poly f v Natural -> PM.Poly f v Natural
        onlyDegree j (PM.P terms) = PM.P [ term | term@(_, m) <- terms, deg m == j ]

-- | Total degree of a monomial: the sum of all exponents stored in its map.
deg :: PM.Mono v Natural -> Natural
deg (PM.M powers) = M.foldl' (+) zero powers