{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-orphans     #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance where

import           Control.DeepSeq                                     (NFData)
import           Data.Aeson                                          hiding (Bool)
import           Data.Binary                                         (Binary)
import           Data.Bool                                           (bool)
import           Data.Functor.Rep                                    (Representable (..))
import           Data.Map                                            hiding (drop, foldl, foldl', foldr, map, null,
                                                                      splitAt, take, toList)
import           GHC.Generics                                        (Par1 (..))
import           Prelude                                             (Show, head, mempty, pure, return, show, ($), (++),
                                                                      (.), (<$>), (<))
import qualified Prelude                                             as Haskell
import           Test.QuickCheck                                     (Arbitrary (arbitrary), Gen, elements)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Vector                             (Vector, unsafeToVector)
import           ZkFold.Prelude                                      (genSubset, length)
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Var
import           ZkFold.Symbolic.Data.FieldElement                   (FieldElement (..))
import           ZkFold.Symbolic.MonadCircuit

------------------------------------- Instances -------------------------------------

-- | Random single-output circuits for property tests.
--
-- Generation starts from 'mempty' whose output is replaced by a randomly
-- drawn input variable (the index comes from the @Arbitrary (Rep i)@
-- instance), and then performs ten random growth steps with 'arbitrary''.
instance
  ( Arithmetic a
  , Arbitrary a
  , Binary a
  , Binary (Rep p)
  , Arbitrary (Rep i)
  , Binary (Rep i)
  , Haskell.Ord (Rep i)
  , NFData (Rep i)
  , Representable i
  , Haskell.Foldable i
  ) => Arbitrary (ArithmeticCircuit a p i Par1) where
    arbitrary = do
        -- Which input the trivial starting circuit will output.
        inputIx <- arbitrary
        let seed = mempty {acOutput = Par1 (toVar (InVar inputIx))}
        grown <- arbitrary' (FieldElement seed) 10
        return (fromFieldElement grown)

-- | Random circuits with @l@ outputs.
--
-- A random single-output circuit is generated first. 'genSubset' then picks
-- @l@ of its variables (as listed by 'getAllVars'), and those become the
-- outputs.
-- NOTE(review): 'unsafeToVector' presumably does not check the length, so
-- this relies on 'genSubset' returning exactly @l@ elements — confirm.
instance
  ( Arithmetic a
  , Arbitrary a
  , Binary a
  , Binary (Rep p)
  , Arbitrary (Rep i)
  , Binary (Rep i)
  , Haskell.Ord (Rep i)
  , NFData (Rep i)
  , Representable i
  , Haskell.Foldable i
  , KnownNat l
  ) => Arbitrary (ArithmeticCircuit a p i (Vector l)) where
    arbitrary = do
        circuit <- arbitrary @(ArithmeticCircuit a p i Par1)
        chosen  <- genSubset (value @l) (getAllVars circuit)
        -- Re-target the same circuit so that it outputs the chosen variables.
        let outputs = toVar <$> unsafeToVector chosen
        return circuit {acOutput = outputs}

-- | Grow a random circuit by @iter@ random steps (used by the 'Arbitrary'
-- instances in this module).
--
-- Each step draws two variables @li@ and @ri@ from 'getAllVars' of the
-- current circuit. The draws use 'elements', so the two may coincide. It
-- builds single-output circuits @l@ and @r@ that output those variables,
-- then continues with one of @l + r@, @l * r@, @l - r@, @l // r@, or the
-- circuit extended by 'createRangeConstraint' (bound 10) with output @li@.
--
-- When @iter@ reaches 0, a circuit with exactly one variable is replaced by
-- that variable squared (@newF * newF@). Any other circuit is returned as-is.
--
-- A circuit with no variables is returned unchanged. There is nothing to
-- draw from, and both 'elements' and 'head' fail on an empty list.
arbitrary' ::
  forall a p i .
  (Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Haskell.Ord (Rep i), NFData (Rep i)) =>
  (Representable i, Haskell.Foldable i) =>
  FieldElement (ArithmeticCircuit a p i) -> Natural ->
  Gen (FieldElement (ArithmeticCircuit a p i))
arbitrary' ac 0
    -- Guard: 'newF' below calls 'head', which would crash on no variables.
    | Haskell.null vars = return ac
    | Haskell.otherwise = return $ bool ac (newF * newF) (numOfVars < 2)
  where
    vars :: [SysVar i]
    vars = getAllVars $ fromFieldElement ac
    numOfVars :: Natural
    numOfVars = length vars
    -- The same circuit re-targeted to output its first variable. It is only
    -- evaluated once 'vars' is known to be non-empty.
    newF :: FieldElement (ArithmeticCircuit a p i)
    newF = FieldElement (fromFieldElement ac) { acOutput = pure (toVar $ head vars)}
arbitrary' ac iter
    -- Guard: 'elements' errors on an empty list.
    | Haskell.null vars = return ac
    | Haskell.otherwise = do
        li <- elements vars
        ri <- elements vars
        let (l, r) = ( FieldElement (fromFieldElement ac) { acOutput = pure (toVar li)}
                     , FieldElement (fromFieldElement ac) { acOutput = pure (toVar ri)})
            -- Same circuit with an extra range constraint, outputting li.
            c = FieldElement (fromFieldElement $ createRangeConstraint ac (fromConstant @Natural 10)) { acOutput = pure (toVar li)}
        ac' <- elements
            [ l + r
            , l * r
            , l - r
            , l // r
            , c
            ]
        arbitrary' ac' (iter -! 1)
  where
    vars :: [SysVar i]
    vars = getAllVars (fromFieldElement ac)


-- | Attach a range constraint to the circuit of a field element while
-- keeping the element's output variable unchanged.
--
-- Inside the circuit, a fresh variable is allocated with
-- @newAssigned (const zero)@, and 'rangeConstraint' is applied to that fresh
-- variable with the given bound. The constraint is NOT placed on the
-- element's own variable.
-- NOTE(review): presumably the fresh variable always evaluates to zero, so
-- the added constraint is trivially satisfiable — confirm this is intended.
createRangeConstraint :: Symbolic c => FieldElement c -> BaseField c -> FieldElement c
createRangeConstraint (FieldElement x) bound =
    FieldElement (fromCircuitF x (\(Par1 v) -> Par1 <$> constrainFresh v bound))
  where
    -- Range-constrain a newly assigned variable, then return @v@ untouched.
    constrainFresh :: MonadCircuit var a w m => var -> a -> m var
    constrainFresh v b = do
      fresh <- newAssigned (Haskell.const zero)
      rangeConstraint fresh b
      return v

-- TODO: make it more readable
instance (Show a, Show (o (Var a i)), Show (Var a i), Show (Rep i), Haskell.Ord (Rep i)) => Show (ArithmeticCircuit a p i o) where
    -- Renders the constraint system, the range constraints and the outputs
    -- as a record-like string. Witness generators and folds are not shown.
    show circuit =
        "ArithmeticCircuit { acSystem = " ++ systemStr
            ++ "\n, acRange = " ++ rangeStr
            ++ "\n, acOutput = " ++ outputStr
            ++ " }"
      where
        systemStr = show (acSystem circuit)
        rangeStr  = show (acRange circuit)
        outputStr = show (acOutput circuit)

-- TODO: add witness generation info to the JSON object
instance (ToJSON a, ToJSON (o (Var a i)), ToJSONKey a, FromJSONKey (Var a i), ToJSON (Rep i)) => ToJSON (ArithmeticCircuit a p i o) where
    -- Serialises the constraint system, the range constraints and the outputs
    -- under the keys "system", "range" and "output". Witness generators and
    -- folds are not included.
    toJSON circuit = object fields
      where
        fields =
            [ "system" .= acSystem circuit
            , "range"  .= acRange circuit
            , "output" .= acOutput circuit
            ]

-- TODO: properly restore the witness generation function
instance (FromJSON a, FromJSON (o (Var a i)), ToJSONKey (Var a i), FromJSONKey a, Haskell.Ord a, Haskell.Ord (Rep i), FromJSON (Rep i)) => FromJSON (ArithmeticCircuit a p i o) where
    -- Reads the "system", "range" and "output" keys, which are the same keys
    -- the ToJSON instance writes. The witness and fold maps are not part of
    -- the JSON and are set to 'empty', so a decoded circuit carries no
    -- witness-generation information.
    parseJSON = withObject "ArithmeticCircuit" $ \obj -> do
        sys <- obj .: "system"
        rng <- obj .: "range"
        out <- obj .: "output"
        pure ArithmeticCircuit
            { acSystem  = sys
            , acRange   = rng
            , acOutput  = out
            , acWitness = empty
            , acFold    = empty
            }