{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module ZkFold.Base.Protocol.Protostar.AlgebraicMap where
import Data.ByteString (ByteString)
import Data.Functor.Rep (Representable (..))
import Data.List (foldl')
import Data.Map.Strict (Map, keys)
import qualified Data.Map.Strict as M
import Prelude (fmap, zip, ($), (.), (<$>))
import qualified Prelude as P
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import qualified ZkFold.Base.Algebra.Polynomials.Multivariate as PM
import ZkFold.Base.Algebra.Polynomials.Multivariate
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Data.Vector (Vector)
import ZkFold.Base.Protocol.Protostar.ArithmetizableFunction (ArithmetizableFunction (..))
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Compiler
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
import ZkFold.Symbolic.Data.Eq
-- | Algebraic map of @a@ over the ring @f@.
--
-- The degree parameter @d@ does not occur in the method's type, so callers
-- must fix it with a type application (this is why the module enables
-- @AllowAmbiguousTypes@).
class (Ring f) => AlgebraicMap f i (d :: Natural) a where
    -- | Evaluate the algebraic map of @a@.
    --
    -- Arguments, in order:
    --
    --   * the object whose map is evaluated;
    --   * an @i@-shaped collection of ring elements;
    --   * a vector of @k@ lists of ring elements;
    --   * a vector of @k - 1@ ring elements;
    --   * a single ring element.
    --
    -- Returns a list of ring elements.
    --
    -- NOTE(review): in Protostar these are presumably the public input,
    -- prover messages, verifier challenges and slack term; confirm with the
    -- protocol code before relying on that reading.
    algebraicMap :: a
                 -> i f
                 -> Vector k [f]
                 -> Vector (k-1) f
                 -> f
                 -> [f]
-- | Algebraic map of an 'ArithmetizableFunction'.
--
-- The circuit's constraint polynomials are split by 'degreeDecomposition'
-- into their homogeneous parts of degree @0 .. d@. Every part is then
-- evaluated at the point described by the arguments, and the parts are
-- recombined with 'padDecomposition' using the last argument as the pad.
-- The third argument is not used.
instance
    ( Ring f
    , Representable i
    , KnownNat (d + 1)
    , Arithmetic a
    , Scale a f
    ) => AlgebraicMap f i d (ArithmetizableFunction a i p) where
    algebraicMap ArithmetizableFunction { afCircuit = circuit } pi pm _ pad =
        padDecomposition pad f_sps_uni
        where
            -- All constraint polynomials of the circuit.
            sys :: [PM.Poly a (SysVar i) Natural]
            sys = M.elems (acSystem circuit)

            -- Values for the circuit's new (internal) variables: the first
            -- entry of @pm@, paired positionally with the witness keys.
            -- NOTE(review): 'V.head' fails when @pm@ is empty, and 'zip'
            -- silently drops surplus keys or values if the lengths differ.
            -- This assumes the first entry of @pm@ lists one value per witness
            -- key, in key order; confirm with callers.
            witness :: Map ByteString f
            witness = M.fromList $ zip (keys $ acWitness circuit) (V.head pm)

            -- Assignment of every system variable. Inputs are read from @pi@.
            -- New variables are read from 'witness', defaulting to 'zero'
            -- when no value is present.
            varMap :: SysVar i -> f
            varMap (InVar inV)   = index pi inV
            varMap (NewVar newV) = M.findWithDefault zero newV witness

            -- Homogeneous parts of the system, indexed by degree 0 .. d.
            f_sps :: Vector (d + 1) [PM.Poly a (SysVar i) Natural]
            f_sps = degreeDecomposition @_ @d sys

            -- Each homogeneous part evaluated under 'varMap'.
            f_sps_uni :: Vector (d + 1) [f]
            f_sps_uni = fmap (PM.evalPolynomial PM.evalMonomial varMap) <$> f_sps
-- | Recombine a degree decomposition, padding the lower-degree components.
--
-- Given @pad@ and a vector of @n@ components @c_0 .. c_(n-1)@, each component
-- @c_j@ is multiplied element-wise by @pad ^ (d - j)@, where @d = n - 1@. The
-- scaled components are then summed element-wise.
--
-- The sum uses 'P.zipWith', so the result is as long as the shortest
-- component. An empty vector (@n = 0@) yields @[]@.
padDecomposition :: forall f n .
    ( MultiplicativeMonoid f
    , AdditiveMonoid f
    , KnownNat n
    ) => f -> V.Vector n [f] -> [f]
padDecomposition pad = sumComponents . P.foldr (:) [] . V.mapWithIx scaleComponent
    where
        -- Highest index of the vector ('value @n' is its length).
        d :: Natural
        d = value @n -! 1

        -- Multiply every entry of component @j@ by @pad ^ (d - j)@.
        scaleComponent :: Natural -> [f] -> [f]
        scaleComponent j p = ((pad ^ (d -! j)) *) <$> p

        -- Element-wise sum of the scaled components.
        -- Seeding the fold with the first component, rather than with
        -- @P.repeat zero@, keeps the n = 0 case finite. The old seed returned
        -- an infinite list of zeros there.
        sumComponents :: [[f]] -> [f]
        sumComponents []       = []
        sumComponents (c : cs) = foldl' (P.zipWith (+)) c cs
-- | Split each polynomial of @lmap@ into its homogeneous parts.
--
-- For every index @j@ in @0 .. d@, the entry at @j@ holds one polynomial per
-- input polynomial, in the same order. Each such polynomial keeps only the
-- terms whose total degree ('deg') is exactly @j@. Terms of degree greater
-- than @d@ therefore appear in no entry.
--
-- The type variables stay in the order @f d v@, so callers can select the
-- degree with @degreeDecomposition \@_ \@d@.
degreeDecomposition :: forall f d v . KnownNat (d + 1) => [Poly f v Natural] -> V.Vector (d + 1) [Poly f v Natural]
degreeDecomposition lmap = tabulate (degree_j . toConstant)
    where
        -- Degree-j part of every polynomial in lmap.
        degree_j :: Natural -> [Poly f v Natural]
        degree_j j = P.fmap (leaveDeg j) lmap

        -- Keep only the terms whose total degree is exactly j.
        leaveDeg :: Natural -> PM.Poly f v Natural -> PM.Poly f v Natural
        leaveDeg j (PM.P monomials) =
            PM.P $ P.filter (\(_, m) -> deg m == j) monomials
-- | Total degree of a monomial: the sum of the exponents stored in its
-- variable-to-exponent map.
deg :: PM.Mono v Natural -> Natural
deg (PM.M m) = sum m