{-# LANGUAGE TypeOperators #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit (
        ArithmeticCircuit,
        Constraint,
        Var,
        witnessGenerator,
        -- high-level functions
        optimize,
        desugarRanges,
        emptyCircuit,
        idCircuit,
        naturalCircuit,
        inputPayload,
        guessOutput,
        -- low-level functions
        eval,
        eval1,
        exec,
        exec1,
        -- information about the system
        acSizeN,
        acSizeM,
        acSizeR,
        acSystem,
        acValue,
        acPrint,
        -- Variable mapping functions
        hlmap,
        hpmap,
        mapVarArithmeticCircuit,
        -- Arithmetization type fields
        acWitness,
        acInput,
        acOutput,
        -- Testing functions
        checkCircuit,
        checkClosedCircuit,
        isConstantInput
    ) where

import           Control.DeepSeq                                         (NFData)
import           Control.Monad                                           (foldM)
import           Control.Monad.State                                     (execState, runState)
import           Data.Binary                                             (Binary)
import           Data.Foldable                                           (for_)
import           Data.Functor.Rep                                        (Representable (..), mzipRep)
import           Data.Map                                                hiding (drop, foldl, foldr, map, null, splitAt,
                                                                          take)
import qualified Data.Map.Monoidal                                       as M
import qualified Data.Set                                                as S
import           Data.Traversable                                        (for)
import           Data.Tuple                                              (swap)
import           Data.Void                                               (absurd)
import           GHC.Generics                                            (U1 (..), (:*:))
import           Numeric.Natural                                         (Natural)
import           Prelude                                                 hiding (Num (..), drop, length, product,
                                                                          splitAt, sum, take, (!!), (^))
import           Test.QuickCheck                                         (Arbitrary, Property, arbitrary, conjoin,
                                                                          property, withMaxSuccess, (===))
import           Text.Pretty.Simple                                      (pPrint)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Polynomials.Multivariate            (evalMonomial, evalPolynomial)
import           ZkFold.Base.Data.HFunctor                               (hmap)
import           ZkFold.Base.Data.Product                                (fstP, sndP)
import           ZkFold.Prelude                                          (length)
import           ZkFold.Symbolic.Class                                   (fromCircuit2F)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance     ()
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal     (Arithmetic, ArithmeticCircuit (..),
                                                                          Constraint, SysVar (..), Var (..),
                                                                          WitVar (..), acInput, crown, eval, eval1,
                                                                          exec, exec1, hlmap, hpmap, witnessGenerator)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Optimization
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Var          (toVar)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Witness      (WitnessF)
import           ZkFold.Symbolic.Data.Combinators                        (expansion)
import           ZkFold.Symbolic.MonadCircuit                            (MonadCircuit (..))

--------------------------------- High-level functions --------------------------------

-- | Emits polynomial constraints forcing the value of variable @i@, read as a
-- natural number, to be at most @b@.
--
-- The variable is split into as many bits as @b@ has ('expansion'), and the
-- comparison @i <= b@ is then evaluated bit by bit. The explanation below
-- assumes 'binaryExpansion' and 'expansion' both list bits
-- least-significant first (TODO: confirm against their definitions).
desugarRange :: (Arithmetic a, MonadCircuit i a w m) => i -> a -> m ()
desugarRange i b
  -- @-1@ is the largest field element, so every value is already in range.
  | b == negate one = return ()
  | otherwise = do
    let bs = binaryExpansion (toConstant b)
    is <- expansion (length bs) i
    -- Low-order positions where @b@ has a 1 bit can never make @i@ exceed
    -- @b@, so the comparison starts at the first 0 bit of @b@.
    case dropWhile ((== one) . fst) (zip bs is) of
      -- @b@ is all ones: fitting into @length bs@ bits (presumably enforced
      -- by 'expansion') is already sufficient.
      [] -> return ()
      ((_, k0) : ds) -> do
        -- At the first 0 bit of @b@, the low part of @i@ is <= the low part
        -- of @b@ iff the corresponding bit of @i@ is 0.
        z <- newAssigned (one - ($ k0))
        ge <- foldM (\j (c, k) -> newAssigned $ forceGE j c k) z ds
        -- The accumulated "i <= b" flag must be 1.
        constraint (($ ge) - one)
  where
    -- Updates the running "i <= b so far" flag @j@ with the next bit @k@ of
    -- @i@ and the corresponding bit @c@ of @b@.
    forceGE j c k
      -- @b@ has 0 here: @i@ must have 0 too, and the lower bits must fit.
      | c == zero = ($ j) * (one - ($ k))
      -- @b@ has 1 here: a 0 bit in @i@ makes the comparison hold
      -- regardless of lower bits; a 1 bit carries the flag over unchanged.
      | otherwise = one + ($ k) * (($ j) - one)

-- | Desugars range constraints into polynomial constraints.
--
-- Every variable registered in 'acRange' under a bound @k@ is passed to
-- 'desugarRange' together with @k@. The result has an empty 'acRange' and
-- the same output as the argument.
desugarRanges ::
  (Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Ord (Rep i)) =>
  ArithmeticCircuit a p i o -> ArithmeticCircuit a p i o
desugarRanges c =
  let ranges = [ (toVar v, k) | (k, s) <- M.toList (acRange c), v <- S.toList s ]
      -- Constraints are generated in 'State' over a copy of the circuit whose
      -- output is replaced with 'U1'; 'for_' discards the unit results
      -- instead of accumulating them in a list.
      r' = execState (for_ ranges (uncurry desugarRange)) c { acOutput = U1 }
   in r' { acRange = mempty, acOutput = acOutput c }

-- | A circuit with no constraints, no range checks, no witness variables,
-- an empty fold map and no output ('U1').
emptyCircuit :: ArithmeticCircuit a p i U1
emptyCircuit = ArithmeticCircuit empty M.empty empty empty U1

-- | Given a natural transformation
-- from payload @p@ and input @i@ to output @o@,
-- returns a corresponding arithmetic circuit
-- where outputs computing the payload are unconstrained.
naturalCircuit ::
  ( Arithmetic a, Representable p, Representable i, Traversable o
  , Binary a, Binary (Rep p), Binary (Rep i), Ord (Rep i)) =>
  (forall x. p x -> i x -> o x) -> ArithmeticCircuit a p i o
naturalCircuit f = uncurry crown $ swap $ flip runState emptyCircuit $
  -- Apply @f@ to the position tables of @p@ and @i@, so each output slot
  -- records which payload ('Left') or input ('Right') position it came from.
  for (f (tabulate Left) (tabulate Right)) $
    -- Payload positions become fresh unconstrained witness variables;
    -- input positions map directly to the corresponding input variable.
    either (unconstrained . pure . WExVar) (return . toVar . InVar)

-- | Identity circuit which returns its input @i@ and doesn't use the payload.
--
-- It has no constraints; its output is simply 'acInput'.
idCircuit :: (Representable i, Semiring a) => ArithmeticCircuit a p i i
idCircuit = emptyCircuit { acOutput = acInput }

-- | Payload of an input to arithmetic circuit.
-- To be used as an argument to 'compileWith'.
--
-- Payload positions become 'WExVar' witnesses, and input positions become
-- 'WSysVar' witnesses over the corresponding 'InVar'.
inputPayload ::
  (Representable p, Representable i) =>
  (forall x. p x -> i x -> o x) -> o (WitnessF a (WitVar p i))
inputPayload f =
  f (tabulate (pure . WExVar)) (tabulate (pure . WSysVar . InVar))

-- | Turns a circuit into a checker for a guessed output.
--
-- The resulting circuit takes the original input together with a candidate
-- output (@i :*: o@). It recomputes the output of @c@ from the @i@ part and
-- constrains each component to equal the corresponding component of the
-- candidate. It has no output of its own ('U1').
guessOutput ::
  (Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Binary (Rep o)) =>
  (Ord (Rep i), Ord (Rep o), NFData (Rep i), NFData (Rep o)) =>
  (Representable i, Representable o, Foldable o) =>
  ArithmeticCircuit a p i o -> ArithmeticCircuit a p (i :*: o) U1
guessOutput c = fromCircuit2F (hlmap fstP c) (hmap sndP idCircuit) $ \o o' -> do
  -- Pair computed and guessed outputs position by position and require
  -- each pair to be equal.
  for_ (mzipRep o o') $ \(i, j) -> constraint (\x -> x i - x j)
  return U1

----------------------------------- Information -----------------------------------

-- | Calculates the number of constraints in the system
-- (the number of entries in 'acSystem').
acSizeN :: ArithmeticCircuit a p i o -> Natural
acSizeN = length . acSystem

-- | Calculates the number of variables in the system, counted as the number
-- of entries in 'acWitness'.
acSizeM :: ArithmeticCircuit a p i o -> Natural
acSizeM = length . acWitness

-- | Calculates the number of range lookups in the system: the total number
-- of variables registered in 'acRange', summed over all bounds.
acSizeR :: ArithmeticCircuit a p i o -> Natural
acSizeR = sum . map length . M.elems . acRange

-- | Evaluates a closed circuit (no payload, no input) with 'exec' and
-- returns its output values.
acValue :: (Arithmetic a, Functor o) => ArithmeticCircuit a U1 U1 o -> o a
acValue = exec

-- | Prints the constraint system, the witness, and the output.
--
-- For a closed circuit, writes to stdout (labels via 'putStr', values via
-- 'pPrint'): the number of constraints, the number of witness variables,
-- the constraints themselves, the generated witness, the output variables
-- and the evaluated output value.
--
-- TODO: Move this elsewhere (?)
-- TODO: Check that all arguments have been applied.
acPrint ::
  (Arithmetic a, Show a, Show (o (Var a U1)), Show (o a), Functor o) =>
  ArithmeticCircuit a U1 U1 o -> IO ()
acPrint ac = do
    let m = elems (acSystem ac)
        w = witnessGenerator ac U1 U1
        v = acValue ac
        o = acOutput ac
    putStr "System size: "
    pPrint $ acSizeN ac
    putStr "Variable size: "
    pPrint $ acSizeM ac
    putStr "Matrices: "
    pPrint m
    putStr "Witness: "
    pPrint w
    putStr "Output: "
    pPrint o
    putStr "Value: "
    pPrint v

---------------------------------- Testing -------------------------------------

-- | Property: the witness generated for circuit @c@ does not depend on the
-- input. For arbitrary inputs @x@, @y@ and payload @p@, 'witnessGenerator'
-- must return the same map.
isConstantInput ::
  ( Arithmetic a, Show a, Representable p, Representable i
  , Show (p a), Show (i a), Arbitrary (p a), Arbitrary (i a)
  ) => ArithmeticCircuit a p i o -> Property
isConstantInput c = property $ \x y p ->
  witnessGenerator c p x === witnessGenerator c p y

-- | Property: every constraint of a closed circuit (no payload, no input)
-- evaluates to zero under the witness the circuit itself generates.
--
-- The circuit has no inputs to vary, so a single test case is run.
checkClosedCircuit
    :: forall a o
     . Arithmetic a
    => Show a
    => ArithmeticCircuit a U1 U1 o
    -> Property
checkClosedCircuit c = withMaxSuccess 1 $ conjoin [ testPoly p | p <- elems (acSystem c) ]
    where
        w = witnessGenerator c U1 U1
        testPoly p = evalPolynomial evalMonomial varF p === zero
        -- Looks up a witness variable; 'Data.Map.!' throws if the generated
        -- witness has no entry for it.
        varF (NewVar v) = w ! v
        -- 'Rep U1' is uninhabited, so there are no input variables.
        varF (InVar v)  = absurd v

-- | Property test: every constraint in the circuit's system evaluates to
-- 'zero' under the witness produced by 'witnessGenerator'.
--
-- Each test case draws one random input and one random payload. It builds the
-- witness map once and checks all constraints from 'acSystem' against that
-- shared witness. Building it once avoids regenerating the full witness for
-- every constraint, which would make a single test case quadratic in the
-- circuit size.
checkCircuit
    :: Arbitrary (p a)
    => Arbitrary (i a)
    => Arithmetic a
    => Show a
    => Representable p
    => Representable i
    => ArithmeticCircuit a p i o
    -> Property
checkCircuit c = property $ do
    ins <- arbitrary
    pls <- arbitrary
    let w = witnessGenerator c pls ins
        -- Internal variables are looked up in the witness map. Input
        -- variables are read directly from the generated input.
        varF (NewVar v) = w ! v
        varF (InVar v)  = index ins v
    return $ conjoin
        [ evalPolynomial evalMonomial varF p === zero | p <- elems (acSystem c) ]