{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE UndecidableInstances #-}
module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (
mapVarArithmeticCircuit,
ArithmeticCircuitTest(..)
) where
import Data.Functor.Rep (Representable (..))
import Data.Map hiding (drop, foldl, foldr, fromList, map, null,
splitAt, take, toList)
import qualified Data.Map as Map
import GHC.IsList (IsList (..))
import Numeric.Natural (Natural)
import Prelude hiding (Num (..), drop, length, product, splitAt,
sum, take, (!!), (^))
import Test.QuickCheck (Arbitrary (arbitrary), Gen)
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Polynomials.Multivariate
import ZkFold.Base.Data.ByteString (toByteString)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), SysVar (..),
Var (..), VarField, getAllVars)
-- | A test case: an arithmetic circuit paired with a value of the circuit's
-- input type @i a@.
data ArithmeticCircuitTest a i o = ArithmeticCircuitTest
    { arithmeticCircuit :: ArithmeticCircuit a i o
      -- ^ The circuit under test.
    , witnessInput      :: i a
      -- ^ An input for the circuit (presumably used to compute its witness).
    }
-- | Shows the circuit, then the witness input on a following line prefixed
-- with @witnessInput: @.
instance (Show (ArithmeticCircuit a i o), Show a, Show (i a)) => Show (ArithmeticCircuitTest a i o) where
    show :: ArithmeticCircuitTest a i o -> String
    show test = concat
        [ show (arithmeticCircuit test)
        , ",\nwitnessInput: "
        , show (witnessInput test)
        ]
-- | Generates an arbitrary circuit and, independently of it, an arbitrary
-- witness input (the circuit is drawn first, then the input).
instance (Arithmetic a, Arbitrary (i a), Arbitrary (ArithmeticCircuit a i f), Representable i) => Arbitrary (ArithmeticCircuitTest a i f) where
    arbitrary :: Gen (ArithmeticCircuitTest a i f)
    arbitrary = ArithmeticCircuitTest <$> arbitrary <*> arbitrary
-- | Renames every internal ('NewVar') variable of the circuit to a canonical
-- name.
--
-- The @n@-th 'NewVar' reported by 'getAllVars' (counting from @0@) is renamed
-- to @toByteString \@VarField (fromConstant n)@. Input variables ('InVar') and
-- constants ('ConstVar') are left untouched, and the witness input is
-- returned unchanged.
--
-- The renaming is applied to:
--
-- * every polynomial in 'acSystem'; the system's keys are also replaced by the
--   canonical names, assigned in ascending order of the original keys;
-- * the keys of 'acWitness'. Each witness function is wrapped so that it
--   still receives a map keyed by the /old/ variable names;
-- * every variable in 'acOutput'.
--
-- Lookups into the rename table use '(!)'. This assumes every 'NewVar'
-- occurring in the system and outputs is reported by 'getAllVars'.
mapVarArithmeticCircuit :: (Field a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i o -> ArithmeticCircuitTest a i o
mapVarArithmeticCircuit (ArithmeticCircuitTest ac wi) =
    let vars = [v | NewVar v <- getAllVars ac]
        asc = [ toByteString @VarField (fromConstant @Natural x) | x <- [0..] ]
        -- 'Map.fromList' rather than 'Map.fromAscList'. The latter does not
        -- check that its keys are ascending and builds a corrupt map when they
        -- are not. Nothing here establishes the order of @vars@, nor that the
        -- serialised indices in @asc@ are ascending as 'ByteString's.
        forward = Map.fromList $ zip vars asc
        backward = Map.fromList $ zip asc vars
        varF (InVar v)  = InVar v
        varF (NewVar v) = NewVar (forward ! v)
        mappedCircuit = ac
            { acSystem  = fromList $ zip asc $ evalPolynomial evalMonomial (var . varF) <$> elems (acSystem ac)
              -- Re-key the witness map by the new names. Each witness function
              -- is handed a map keyed by new names; translate it back through
              -- @forward@ so the function keeps reading the old names.
            , acWitness = (`Map.compose` backward) $ (\f i m -> f i (Map.compose m forward)) <$> acWitness ac
            }
        varG = \case
            SysVar v   -> SysVar (varF v)
            ConstVar c -> ConstVar c
        mappedOutputs = varG <$> acOutput ac
    in ArithmeticCircuitTest mappedCircuit {acOutput = mappedOutputs} wi