{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE TypeApplications     #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-orphans     #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance where

import           Control.Monad                                             (foldM, guard, replicateM)
import           Data.Aeson                                                hiding (Bool)
import           Data.Map                                                  hiding (drop, foldl, foldl', foldr, map,
                                                                            null, splitAt, take)
import           Data.Traversable                                          (for)
import qualified Data.Zip                                                  as Z
import           Numeric.Natural                                           (Natural)
import           Prelude                                                   (Integer, const, id, mempty, pure, return,
                                                                            show, type (~), ($), (++), (.), (<$>),
                                                                            (>>=))
import qualified Prelude                                                   as Haskell
import           System.Random                                             (mkStdGen)
import           Test.QuickCheck                                           (Arbitrary (..))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import qualified ZkFold.Base.Data.Vector                                   as V
import           ZkFold.Base.Data.Vector                                   (Vector (..))
import           ZkFold.Prelude                                            (length)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators    (embedAll, embedV, expansion, foldCircuit,
                                                                            horner, invertC, isZeroC)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal       hiding (constraint)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint (MonadBlueprint (..), circuit, circuitN)
import           ZkFold.Symbolic.Compiler.Arithmetizable                   (SymbolicData (..))
import           ZkFold.Symbolic.Data.Bool
import           ZkFold.Symbolic.Data.Conditional
import           ZkFold.Symbolic.Data.DiscreteField
import           ZkFold.Symbolic.Data.Eq

------------------------------------- Instances -------------------------------------

-- | An arithmetic circuit is already in its own symbolic form.
--
-- The number of field elements it occupies ('TypeSize') is its number of
-- outputs @n@.
instance Arithmetic a => SymbolicData a (ArithmeticCircuit n a) where
    type TypeSize a (ArithmeticCircuit n a) = n

    -- The circuit is its own flattened representation.
    pieces = id

    -- Re-attach a vector of output variables to a circuit by applying the
    -- 'ArithmeticCircuit' constructor directly.
    restore = ArithmeticCircuit

-- TODO: I had to add these constraints and I don't like them
--
-- | Declares the 'Order' of a circuit with @n@ outputs over @a@ to be
-- @n * Order a@. The instance has no methods; it only fixes the associated
-- type.
--
-- The two 'KnownNat' constraints are stated on the computed order and on
-- @Log2 (order - 1) + 1@; presumably the latter is the bit-count required by
-- 'Finite' (or one of its superclasses) -- TODO confirm against the class
-- definition.
--
-- NOTE(review): the number of @n@-tuples over @a@ would be @Order a ^ n@, not
-- @n * Order a@; confirm that the product is what is intended here.
instance
    ( KnownNat (n * Order a)
    , KnownNat (Log2 ((n * Order a) - 1) + 1)
    ) => Finite (ArithmeticCircuit n a) where
    type Order (ArithmeticCircuit n a) = n * Order a

-- | Output-wise addition of two circuits with the same number of outputs.
instance Arithmetic a => AdditiveSemigroup (ArithmeticCircuit n a) where
    r1 + r2 = circuitN $ do
        -- Run both operands in the same blueprint to obtain their output variables.
        is <- runCircuit r1
        js <- runCircuit r2
        -- For every pair of corresponding outputs, assign a new variable
        -- from the polynomial @x i + x j@.
        for (Z.zip is js) $ \(i, j) -> newAssigned (\x -> x i + x j)

-- | Scales every output of a circuit by a constant of type @c@.
instance (Arithmetic a, Scale c a) => Scale c (ArithmeticCircuit n a) where
    scale c r = circuitN $ do
        is <- runCircuit r
        -- The scalar is first turned into a base-field element
        -- (@c `scale` one :: a@) and that element is then used to scale the
        -- polynomial value @x i@ of each output.
        for is $ \i -> newAssigned (\x -> (c `scale` one :: a) `scale` x i)

-- | The additive identity: a circuit with @n@ outputs, each assigned from the
-- constant-zero polynomial.
instance (Arithmetic a, KnownNat n) => AdditiveMonoid (ArithmeticCircuit n a) where
    -- 'replicateM' performs the assignment @n@ times, so every output gets its
    -- own variable; the resulting list is wrapped with the 'Vector' constructor.
    zero = circuitN $ Vector <$> replicateM (Haskell.fromIntegral $ value @n) (newAssigned (const zero))

-- | Output-wise negation and subtraction of circuits.
instance (Arithmetic a, KnownNat n) => AdditiveGroup (ArithmeticCircuit n a) where
    -- Each output @i@ is replaced by a new variable assigned from @negate (x i)@.
    negate r = circuitN $ do
        is <- runCircuit r
        for is $ \i -> newAssigned (\x -> negate (x i))

    -- Both operands are run in the same blueprint; corresponding outputs are
    -- paired and a new variable is assigned from @x i - x j@ for each pair.
    r1 - r2 = circuitN $ do
        is <- runCircuit r1
        js <- runCircuit r2
        for (Z.zip is js) $ \(i, j) -> newAssigned (\x -> x i - x j)

-- | Output-wise multiplication of two circuits with the same number of outputs.
instance Arithmetic a => MultiplicativeSemigroup (ArithmeticCircuit n a) where
    r1 * r2 = circuitN $ do
        -- Run both operands in the same blueprint to obtain their output variables.
        is <- runCircuit r1
        js <- runCircuit r2
        -- For every pair of corresponding outputs, assign a new variable
        -- from the polynomial @x i * x j@.
        for (Z.zip is js) $ \(i, j) -> newAssigned (\x -> x i * x j)

-- | Raising a circuit to a natural power, delegated to the generic 'natPow'
-- helper for multiplicative monoids.
instance (Arithmetic a, KnownNat n) => Exponent (ArithmeticCircuit n a) Natural where
    (^) = natPow

-- | The multiplicative identity: the base-field 'one' replicated across all
-- @n@ outputs (via 'pure' on 'Vector') and embedded with 'embedV'.
instance (Arithmetic a, KnownNat n) => MultiplicativeMonoid (ArithmeticCircuit n a) where
    one = embedV (pure one)

-- TODO: The constant will be replicated in all outputs. Is this the desired behaviour?
-- | Lifts a constant of type @b@ into a circuit: the constant is first
-- converted to the base field @a@ and then passed to 'embedAll'.
instance (Arithmetic a, FromConstant b a, KnownNat n) => FromConstant b (ArithmeticCircuit n a) where
    -- The inner 'fromConstant' is the @FromConstant b a@ conversion from the
    -- instance context, not a recursive call.
    fromConstant c = embedAll (fromConstant c)

-- | 'Semiring' instance with an empty body: nothing is defined here beyond
-- what the class and the other instances in this module already provide.
instance (Arithmetic a, KnownNat n) => Semiring (ArithmeticCircuit n a)

-- | 'Ring' instance with an empty body, as for 'Semiring' above it in the
-- class hierarchy.
instance (Arithmetic a, KnownNat n) => Ring (ArithmeticCircuit n a)

-- | Raising a circuit to an integer power, delegated to the generic 'intPowF'
-- helper for fields.
instance (Arithmetic a, KnownNat n) => Exponent (ArithmeticCircuit n a) Integer where
    (^) = intPowF

-- | Field structure on circuits.
instance (Arithmetic a, KnownNat n) => Field (ArithmeticCircuit n a) where
    -- Inversion is delegated to the 'invertC' combinator.
    finv = invertC

    -- A root of unity of the base field, when one exists, is lifted to a
    -- circuit with 'embedAll'; 'Nothing' is propagated unchanged.
    rootOfUnity n = embedAll <$> rootOfUnity n

-- TODO: The only implementation that seems to make sense is when there is only one output.
-- Is that true?
--
-- Anyway, @binaryExpansion@ of an arithmetic circuit will return copies of the same circuit with different outputs.
-- The whole point of this refactor was to avoid this.
--
-- Ideally, we want to return another ArithmeticCircuit with a number of outputs corresponding to the number of bits.
-- This does not align well with the type of @binaryExpansion@
-- | Conversion between a single-output circuit and a circuit with one output
-- per bit, where the number of bits is @NumberOfBits a@.
instance (Arithmetic a, bits ~ NumberOfBits a) => BinaryExpansion (ArithmeticCircuit 1 a) (ArithmeticCircuit bits a) where
    binaryExpansion r = circuitN $ do
        output <- runCircuit r
        -- Extract the single output variable and expand it into
        -- @numberOfBits \@a@ bit variables.
        bits <- expansion (numberOfBits @a) . V.item $ output
        -- 'V.unsafeToVector' does not check the list length; it is assumed
        -- that 'expansion' returns exactly the requested number of bits
        -- -- TODO confirm.
        pure $ V.unsafeToVector bits

    -- Run the bit circuit and recombine its outputs into one variable with
    -- the 'horner' combinator.
    fromBinary bits = circuit $ runCircuit bits >>= horner . V.fromVector

-- | Equality expressed as a circuit: 'isZeroC' applied to the difference of
-- the two operands.
instance (Arithmetic a, KnownNat n) => DiscreteField' (ArithmeticCircuit n a) where
    equal r1 r2 = isZeroC (r1 - r2)

-- | Three-way comparison of two single-output circuits, computed from their
-- binary expansions.
instance Arithmetic a => TrichotomyField (ArithmeticCircuit 1 a) where
    trichotomy r1 r2 =
        let
            bits1 = binaryExpansion r1
            bits2 = binaryExpansion r2
            -- zip pairs of bits in {0,1} to orderings in {-1,0,1}
            comparedBits = bits1 - bits2
            -- least significant bit first,
            -- reverse lexicographical ordering
            --
            -- The step polynomial is y * y * (y - x) + x: for y = 0 it
            -- evaluates to x, and for y = 1 or y = -1 it evaluates to y.
            -- NOTE(review): this reads as "x is the running result and a
            -- non-zero y overrides it" -- confirm against the argument order
            -- used by 'foldCircuit'.
            reverseLexicographical x y = newAssigned $ \p -> p y * p y * (p y - p x) + p x
        in
            -- Collapse the per-bit comparisons into a single output.
            foldCircuit reverseLexicographical comparedBits

-- | A symbolic 'Bool' wrapping a circuit has the same representation as the
-- wrapped circuit: both methods delegate to the circuit's own 'SymbolicData'
-- instance and only add or remove the 'Bool' constructor.
instance Arithmetic a => SymbolicData a (Bool (ArithmeticCircuit n a)) where
    type TypeSize a (Bool (ArithmeticCircuit n a)) = n

    -- Unwrap, then flatten the underlying circuit.
    pieces (Bool b) = pieces b

    -- Restore the underlying circuit, then wrap it.
    restore c = Bool Haskell.. restore c

-- | Tests a circuit with @n >= 1@ outputs for zero, producing a single-output
-- symbolic 'Bool'.
instance (Arithmetic a, KnownNat n, 1 <= n) => DiscreteField (Bool (ArithmeticCircuit 1 a)) (ArithmeticCircuit n a) where
    isZero x = Bool $ circuit $ do
        -- Per-output indicators from 'isZeroC' (presumably 1 where the output
        -- is zero and 0 otherwise -- TODO confirm against 'isZeroC').
        bools <- runCircuit $ isZeroC x
        -- Multiply all indicators together, starting from the first one; the
        -- @1 <= n@ constraint is what makes 'V.head' / 'V.tail' well-typed.
        foldM (\i j -> newAssigned (\p -> p i * p j)) (V.head bools) (V.tail bools)

-- | Symbolic (in)equality of circuits, defined through 'isZero' of their
-- difference.
instance (Arithmetic a, KnownNat n, 1 <= n) => Eq (Bool (ArithmeticCircuit 1 a)) (ArithmeticCircuit n a) where
    x == y = isZero (x - y)

    -- Negation of the equality test.
    x /= y = not $ isZero (x - y)

-- | Branching on a single-output symbolic 'Bool' for any 'SymbolicData' type.
--
-- Both branches are flattened with 'pieces', combined output by output
-- inside one circuit, and the result is rebuilt with 'restore'.
instance {-# OVERLAPPING #-} (SymbolicData a x, n ~ TypeSize a x, KnownNat n) => Conditional (Bool (ArithmeticCircuit 1 a)) x where
    bool brFalse brTrue (Bool b) = restore c o
        where
            f' = pieces brFalse
            t' = pieces brTrue
            -- Split the combined circuit into its body and output variables,
            -- the two arguments that 'restore' expects.
            ArithmeticCircuit c o = circuitN solve

            solve :: forall i m . MonadBlueprint i a m => m (Vector n i)
            solve = do
                ts <- runCircuit t'
                fs <- runCircuit f'
                -- The condition is a one-output circuit; take its only variable.
                bs <- V.item <$> runCircuit b
                -- With x from the true branch and y from the false branch,
                -- b * (x - y) + y evaluates to x when b = 1 and to y when b = 0.
                V.zipWithM (\x y -> newAssigned $ \p -> p bs * (p x - p y) + p y) ts fs

-- TODO: make a proper implementation of Arbitrary
-- | Placeholder generator: it always returns the same circuit, the product
-- @c1 * c2@, and uses no randomness.
instance (Arithmetic a, KnownNat n) => Arbitrary (ArithmeticCircuit n a) where
    arbitrary = return $ c1 * c2
        where
            -- Two circuits with an empty body ('mempty') whose outputs all
            -- point at variable 1 and variable 2 respectively.
            c1, c2 :: ArithmeticCircuit n a
            c1 = ArithmeticCircuit { acCircuit = mempty, acOutput = pure 1}
            c2 = ArithmeticCircuit { acCircuit = mempty, acOutput = pure 2}


-- TODO: make it more readable
-- | Record-like textual rendering showing the constraint system, the input
-- variables, the output variables and the variable order, in that order.
-- The witness generator and the random generator are not shown.
instance (FiniteField a, Haskell.Eq a, Haskell.Show a) => Haskell.Show (ArithmeticCircuit n a) where
    show (ArithmeticCircuit r o) = "ArithmeticCircuit { acSystem = " ++ show (acSystem r) ++ ", acInput = "
        ++ show (acInput r) ++ ", acOutput = " ++ show o ++ ", acVarOrder = " ++ show (acVarOrder r) ++ " }"

-- TODO: add witness generation info to the JSON object
-- | Serialises a circuit as a JSON object with the keys @"system"@,
-- @"input"@, @"output"@ and @"order"@. The outputs are written as a plain
-- list. The witness generator and the random generator are not included.
instance ToJSON a => ToJSON (ArithmeticCircuit n a) where
    toJSON (ArithmeticCircuit r o) = object
        [
            "system" .= acSystem r,
            "input"  .= acInput r,
            "output" .= V.fromVector o,
            "order"  .= acVarOrder r
        ]

-- TODO: properly restore the witness generation function
--
-- | Parses a JSON object with the keys @"system"@, @"input"@, @"output"@ and
-- @"order"@ into a circuit with exactly @n@ outputs.
--
-- Parsing fails when a key is missing or malformed, or when the @"output"@
-- list does not contain exactly @n@ entries.
instance (FromJSON a, KnownNat n) => FromJSON (ArithmeticCircuit n a) where
    parseJSON =
        withObject "ArithmeticCircuit" $ \v -> do
            acSystem   <- v .: "system"
            acInput    <- v .: "input"
            -- The witness generator is not part of the JSON encoding, so a
            -- stub that always yields an empty map is installed instead.
            let acWitness = const empty
            outs       <- v .: "output"
            -- Check the number of parsed outputs (not the number of keys in
            -- the JSON object) against the type-level size @n@. This must
            -- hold before the unchecked 'Vector' constructor is applied.
            guard (length outs == (value @n))
            let acOutput = Vector outs
            acVarOrder <- v .: "order"
            -- The random generator is not serialised either; a fixed seed is used.
            let acRNG = mkStdGen 0
                -- Record wildcards pick up the local bindings above by field name.
                acCircuit = Circuit{..}
            pure ArithmeticCircuit{..}