module ZkFold.Symbolic.Compiler.ArithmeticCircuit (
ArithmeticCircuit,
ArithmeticCircuitTest(..),
Constraint,
Var,
witnessGenerator,
optimize,
desugarRanges,
eval,
eval1,
exec,
exec1,
acSizeN,
acSizeM,
acSizeR,
acSystem,
acValue,
acPrint,
hlmap,
mapVarArithmeticCircuit,
acWitness,
acInput,
acOutput,
checkCircuit,
checkClosedCircuit
) where
import Control.Monad (foldM)
import Control.Monad.State (execState)
import Data.Binary (Binary)
import Data.Foldable (traverse_)
import Data.Functor.Rep (Representable (..))
import Data.Map hiding (drop, foldl, foldr, map, null, splitAt,
    take)
import Data.Void (absurd)
import GHC.Generics (U1 (..))
import Numeric.Natural (Natural)
import Prelude hiding (Num (..), drop, length, product, splitAt,
    sum, take, (!!), (^))
import Test.QuickCheck (Arbitrary, Property, arbitrary, conjoin, property,
    withMaxSuccess, (===))
import Text.Pretty.Simple (pPrint)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial)
import ZkFold.Prelude (length)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance ()
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint,
    SysVar (..), Var (..), acInput, eval, eval1, exec,
    exec1, hlmap, witnessGenerator)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map
import ZkFold.Symbolic.Data.Combinators (expansion)
import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..))
-- | Optimizes an arithmetic circuit.
--
-- Currently this is the identity function: no optimization passes are
-- performed and the circuit is returned unchanged.
optimize :: ArithmeticCircuit a i o -> ArithmeticCircuit a i o
optimize = id
-- | Replaces a single range entry (variable @i@, bound @b@) with ordinary
-- polynomial constraints emitted through 'MonadCircuit'.
--
-- The variable is split into @length bs@ bit variables with 'expansion',
-- where @bs@ is the binary expansion of @b@. Those bits are then folded into
-- a single flag variable, which is finally constrained to equal one.
--
-- NOTE(review): presumably this enforces @0 <= i <= b@. The fold below is a
-- correct @<=@ comparison only if 'binaryExpansion' and 'expansion' both list
-- bits least-significant first — TODO confirm.
desugarRange :: (Arithmetic a, MonadCircuit i a m) => i -> a -> m ()
desugarRange i b
  -- A bound equal to @negate one@ emits nothing (presumably because it is the
  -- largest field element, so the range holds for every value).
  | b == negate one = return ()
  | otherwise = do
    let bs = binaryExpansion (toConstant b)
    -- Decompose @i@ into as many bit variables as the bound has bits.
    is <- expansion (length bs) i
    -- Pairs whose bound bit is one are skipped until the first zero bound bit.
    case dropWhile ((== one) . fst) (zip bs is) of
      -- Every bound bit is one: the bit decomposition alone is all we emit.
      [] -> return ()
      ((_, k0):ds) -> do
        -- The flag starts as "bit k0 of i is zero".
        z <- newAssigned (one - ($ k0))
        -- Fold each remaining (bound bit, input bit) pair into the flag.
        ge <- foldM (\j (c, k) -> newAssigned $ forceGE j c k) z ds
        -- The final flag must be one.
        constraint (($ ge) - one)
  where
    -- New flag value from previous flag @j@, bound bit @c@ and input bit @k@:
    --   c == 0:  j * (1 - k)      (flag survives only when k is zero)
    --   c /= 0:  1 + k * (j - 1)  (flag is one when k is zero, else j)
    forceGE j c k
      | c == zero = ($ j) * (one - ($ k))
      | otherwise = one + ($ k) * (($ j) - one)
-- | Desugars range constraints into polynomial constraints.
--
-- Every @(variable, bound)@ entry of 'acRange' is rewritten with
-- 'desugarRange' inside a 'State' computation over the circuit. The result
-- has an empty 'acRange' and the same 'acOutput' as the input circuit.
desugarRanges ::
  (Arithmetic a, Binary a, Binary (Rep i), Ord (Rep i), Representable i) =>
  ArithmeticCircuit a i o -> ArithmeticCircuit a i o
desugarRanges c =
  let -- Outputs are set to 'U1' while the state computation runs and are
      -- restored from @c@ below.
      r' = flip execState c {acOutput = U1}
         -- 'traverse_' discards the per-range '()' results instead of
         -- collecting them into a list that 'execState' would throw away.
         . traverse_ (uncurry desugarRange)
         $ [(SysVar k, v) | (k, v) <- toList (acRange c)]
   in r' { acRange = mempty, acOutput = acOutput c }
-- | Number of constraint polynomials in the circuit ('acSystem').
acSizeN :: ArithmeticCircuit a i o -> Natural
acSizeN = length . acSystem
-- | Number of entries in the circuit's witness map ('acWitness'), i.e. the
-- number of variables that have a witness function.
acSizeM :: ArithmeticCircuit a i o -> Natural
acSizeM = length . acWitness
-- | Number of range entries in the circuit ('acRange').
acSizeR :: ArithmeticCircuit a i o -> Natural
acSizeR = length . acRange
-- | Evaluates the outputs of a circuit that takes no inputs (input functor
-- 'U1').
acValue :: Functor o => ArithmeticCircuit a U1 o -> o a
acValue r = eval r U1
-- | Pretty-prints a summary of a circuit with no inputs.
--
-- The summary contains the number of constraints, the number of witness
-- entries, the constraint polynomials, the witness, the output variables and
-- their evaluated values.
acPrint :: (Show a, Show (o (Var a U1)), Show (o a), Functor o) => ArithmeticCircuit a U1 o -> IO ()
acPrint ac = do
    let m = elems (acSystem ac)
        w = witnessGenerator ac U1
        v = acValue ac
        o = acOutput ac
    putStr "System size: "
    pPrint $ acSizeN ac
    putStr "Variable size: "
    pPrint $ acSizeM ac
    putStr "Matrices: "
    pPrint m
    putStr "Witness: "
    pPrint w
    putStr "Output: "
    pPrint o
    putStr "Value: "
    pPrint v
-- | QuickCheck property: the witness of a circuit with no inputs satisfies
-- every constraint polynomial, i.e. each polynomial evaluates to 'zero'.
--
-- The circuit has no inputs ('U1') and the property draws no random data,
-- so a single successful run ('withMaxSuccess' 1) is sufficient.
checkClosedCircuit
    :: forall a n
     . Arithmetic a
    => Show a
    => ArithmeticCircuit a U1 n
    -> Property
checkClosedCircuit c = withMaxSuccess 1 $ conjoin [ testPoly p | p <- elems (acSystem c) ]
    where
        w = witnessGenerator c U1
        testPoly p = evalPolynomial evalMonomial varF p === zero
        -- New variables are read from the witness; 'Rep' 'U1' is uninhabited,
        -- so an input variable cannot occur.
        varF (NewVar v) = w ! v
        varF (InVar v)  = absurd v
-- | QuickCheck property: for randomly generated inputs, the witness produced
-- by 'witnessGenerator' satisfies every constraint polynomial of the circuit,
-- i.e. each polynomial evaluates to 'zero'.
checkCircuit
    :: Arbitrary (i a)
    => Arithmetic a
    => Show a
    => Representable i
    => ArithmeticCircuit a i n
    -> Property
checkCircuit c = property $ do
    -- One input and one witness per test run, shared by all constraints.
    -- This avoids regenerating the whole witness separately for every
    -- constraint polynomial.
    ins <- arbitrary
    let w = witnessGenerator c ins
        -- New variables come from the witness; input variables are read
        -- straight from the generated input.
        varF (NewVar v) = w ! v
        varF (InVar v)  = index ins v
    return $ conjoin [ evalPolynomial evalMonomial varF p === zero | p <- elems (acSystem c) ]