{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Eq (
    Eq(..),
    elem
) where

import           Data.Bool                        (bool)
import           Data.Foldable                    (Foldable)
import           Data.Functor.Rep                 (Representable, mzipRep, mzipWithRep)
import           Data.Traversable                 (Traversable, for)
import           Prelude                          (return, ($))
import qualified Prelude                          as Haskell

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Data.Package
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool        (Bool (Bool), BoolType (..), all, any)
import           ZkFold.Symbolic.Data.Combinators (runInvert)
import           ZkFold.Symbolic.MonadCircuit

-- | Equality whose result type @b@ is a class parameter instead of being
-- fixed to 'Haskell.Bool'. The same operators can therefore produce either
-- an ordinary 'Haskell.Bool' or a symbolic boolean such as @'Bool' c@.
class Eq b a where
    infix 4 ==, /=

    -- | Equality test, yielding a boolean of type @b@.
    (==) :: a -> a -> b

    -- | Inequality test, yielding a boolean of type @b@.
    (/=) :: a -> a -> b

-- | @elem x xs@ is 'any' over @xs@ of the predicate @(== x)@: it holds when
-- some element @y@ of @xs@ satisfies @y == x@.
--
-- The comparison uses this module's 'Eq', so the result has whatever
-- boolean type @b@ the caller chooses.
elem :: (BoolType b, Eq b a, Foldable t) => a -> t a -> b
elem x = any (== x)

-- | Ordinary Haskell values are compared with Prelude's equality and
-- produce a plain 'Haskell.Bool'.
instance Haskell.Eq a => Eq Haskell.Bool a where
    (==) = (Haskell.==)
    (/=) = (Haskell./=)

-- | Symbolic equality for a value @c f@ (a representable, traversable
-- collection of field elements in context @c@). The result is a single
-- symbolic 'Bool'.
--
-- Both methods use 'symbolic2F' with two arguments:
--
-- * a function evaluating the answer directly on @f (BaseField c)@ values;
-- * a circuit builder producing the same per-component flags.
--
-- The per-component flags are then folded into a single 'Bool'.
instance (Symbolic c, Haskell.Eq (BaseField c), Representable f, Traversable f)
  => Eq (Bool c) (c f) where
    x == y =
        let
            -- Per-component equality flags, one entry per position of f.
            result = symbolic2F x y
                -- Direct evaluation: 'one' where components are equal,
                -- 'zero' otherwise.
                (mzipWithRep (\i j -> bool zero one (i Haskell.== j)))
                -- Circuit construction for the same flags.
                (\x' y' -> do
                    -- Allocate one variable per component holding x'_k - y'_k.
                    difference <- for (mzipRep x' y') $ \(i, j) ->
                        newAssigned (\w -> w i - w j)
                    -- Keep only the first component of runInvert's result.
                    -- It is named isZeros and is expected to be 'one' exactly
                    -- where the difference is zero, matching the evaluator
                    -- above (NOTE(review): confirm against runInvert).
                    (isZeros, _) <- runInvert difference
                    return isZeros
                )
        in
            -- Equal iff every per-component flag is set.
            all Bool (unpacked result)

    x /= y =
        let
            -- Per-component inequality flags, one entry per position of f.
            result = symbolic2F x y
                -- Direct evaluation: 'one' where components differ,
                -- 'zero' otherwise.
                (mzipWithRep (\i j -> bool zero one (i Haskell./= j)))
                -- Circuit construction: the same zero test as in '==',
                -- then each flag is flipped.
                (\x' y' -> do
                    difference <- for (mzipRep x' y') $ \(i, j) ->
                        newAssigned (\w -> w i - w j)
                    (isZeros, _) <- runInvert difference
                    -- Turn each equality flag into an inequality flag: 1 - isZ.
                    for isZeros $ \isZ ->
                      newAssigned (\w -> one - w isZ)
                )
        in
            -- Different iff at least one per-component flag is set.
            any Bool (unpacked result)