{-# LANGUAGE AllowAmbiguousTypes #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit (
        ArithmeticCircuit,
        Constraint,
        -- high-level functions
        applyArgs,
        optimize,
        -- low-level functions
        eval,
        forceZero,
        -- information about the system
        acSizeN,
        acSizeM,
        acSystem,
        acValue,
        acPrint,
        -- Variable mapping functions
        mapVarArithmeticCircuit,
        mapVarWitness,
        -- Arithmetization type fields
        acWitness,
        acVarOrder,
        acOutput,
        -- Testing functions
        checkCircuit,
        checkClosedCircuit
    ) where

import           Control.Monad.State                                 (execState)
import           Data.Map                                            hiding (drop, foldl, foldr, map, null, splitAt,
                                                                      take)
import           Numeric.Natural                                     (Natural)
import           Prelude                                             hiding (Num (..), drop, length, product, splitAt,
                                                                      sum, take, (!!), (^))
import           Test.QuickCheck                                     (Arbitrary, Property, conjoin, property, vector,
                                                                      withMaxSuccess, (===))
import           Text.Pretty.Simple                                  (pPrint)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Polynomials.Multivariate        (evalMapM, evalPolynomial)
import           ZkFold.Prelude                                      (length)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance ()
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint,
                                                                      apply, eval, forceZero)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map

--------------------------------- High-level functions --------------------------------

-- | Applies the given argument values to the circuit.
--
-- Runs 'apply' on @args@ as a 'State' computation over the circuit and
-- returns the resulting circuit. NOTE(review): presumably @args@ are bound
-- to the circuit's inputs in order — confirm against 'apply'.
--
-- TODO: make this work for different input types.
applyArgs :: forall a . ArithmeticCircuit a -> [a] -> ArithmeticCircuit a
applyArgs r args = execState (apply args) r

-- | Optimizes the constraint system.
--
-- Currently this is the identity: the circuit is returned unchanged.
--
-- TODO: Implement nontrivial optimizations.
optimize :: ArithmeticCircuit a -> ArithmeticCircuit a
optimize = id

----------------------------------- Information -----------------------------------

-- | Calculates the number of constraints in the system, i.e. the number of
-- entries in 'acSystem'.
acSizeN :: ArithmeticCircuit a -> Natural
acSizeN = length . acSystem

-- | Calculates the number of variables in the system, i.e. the number of
-- entries in 'acVarOrder'.
-- The constant @1@ is not counted.
acSizeM :: ArithmeticCircuit a -> Natural
acSizeM = length . acVarOrder

-- | Evaluates the circuit with an empty input assignment ('mempty').
--
-- NOTE(review): presumably only meaningful once all inputs have been
-- applied (e.g. via 'applyArgs'); confirm against 'eval'.
acValue :: ArithmeticCircuit a -> a
acValue r = eval r mempty

-- | Prints the constraint system, the witness, and the output.
--
-- Each item is printed as a label (no trailing newline) followed by the
-- pretty-printed value. The witness and the value are computed from an
-- empty input map.
--
-- TODO: Move this elsewhere (?)
-- TODO: Check that all arguments have been applied.
acPrint :: forall a . Show a => ArithmeticCircuit a -> IO ()
acPrint r = do
    field "System size: "    (acSizeN r)
    field "Variable size: "  (acSizeM r)
    field "Matrices: "       (elems (acSystem r))
    field "Input: "          (acInput r)
    field "Witness: "        (acWitness r empty)
    field "Variable order: " (acVarOrder r)
    field "Output: "         (acOutput r)
    field "Value: "          (acValue r)
  where
    -- Prints a label followed by a pretty-printed value.
    field :: Show b => String -> b -> IO ()
    field label x = putStr label >> pPrint x

---------------------------------- Testing -------------------------------------

-- | Checks that every constraint of the circuit evaluates to 'zero' on the
-- witness computed from an empty input map (intended for closed circuits).
--
-- No random data is involved, so a single QuickCheck run suffices.
checkClosedCircuit :: (Arithmetic a, FromConstant a a, Scale a a, Show a) => ArithmeticCircuit a -> Property
checkClosedCircuit r = withMaxSuccess 1 $ conjoin (map testPoly (elems (acSystem r)))
    where
        -- Witness computed once and shared by every constraint.
        w = acWitness r empty
        -- A constraint holds when its polynomial evaluates to zero; '(!)'
        -- errors (failing the property) if a variable is absent from the witness.
        testPoly p = evalPolynomial evalMapM (w !) p === zero

-- | Checks that, for randomly generated inputs, every constraint of the
-- circuit evaluates to 'zero' on the resulting witness.
--
-- Each test case draws one arbitrary value per entry of 'acInput', computes
-- the witness once, and checks all constraints against that same witness
-- (rather than regenerating inputs and recomputing the witness per
-- constraint).
checkCircuit :: (Arbitrary a, Arithmetic a, FromConstant a a, Scale a a, Show a) => ArithmeticCircuit a -> Property
checkCircuit r = property $ do
    ins <- vector . fromIntegral $ length (acInput r)
    -- Bind each input variable to its random value and extend to a full witness.
    let w = acWitness r . fromList $ zip (acInput r) ins
        -- A constraint holds when its polynomial evaluates to zero; '(!)'
        -- errors (failing the property) if a variable is absent from the witness.
        testPoly p = evalPolynomial evalMapM (w !) p === zero
    return $ conjoin (map testPoly (elems (acSystem r)))