{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Input (
    SymbolicInput (..)
) where

import           Control.DeepSeq                    (NFData)
import           Control.Monad.Representable.Reader (Rep)
import           Data.Functor.Rep                   (Representable)
import           Data.Ord                           (Ord)
import           Data.Type.Equality                 (type (~))
import           Data.Typeable                      (Proxy (..))
import           GHC.Generics                       (Par1 (..))
import           GHC.TypeLits                       (KnownNat)
import           Prelude                            (foldl, ($))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Data.ByteString        (Binary)
import           ZkFold.Base.Data.Vector            (Vector, fromVector)
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Combinators
import           ZkFold.Symbolic.MonadCircuit


-- | Symbolic data that can be supplied as an input to a symbolic computation.
--
-- Besides being 'SymbolicData', an input type must satisfy four requirements:
--
-- * its 'Support' is exactly a 'Proxy' of its 'Context';
-- * its 'Layout' is 'Representable';
-- * the layout's index type ('Rep') is 'Ord', 'Binary' and 'NFData'.
class
    ( SymbolicData d
    , Support d ~ Proxy (Context d)
    , Representable (Layout d)
    , Ord (Rep (Layout d))
    , Binary (Rep (Layout d))
    , NFData (Rep (Layout d))
    ) => SymbolicInput d where
    -- | Symbolic predicate, living in the context of @d@, that reports whether
    -- a value is an acceptable input (for example, the 'Bool' instance checks
    -- that its single wire holds 0 or 1).
    isValid :: d -> Bool (Context d)


-- A 'Bool' input is accepted when its single wire @w@ is boolean.
instance Symbolic c => SymbolicInput (Bool c) where
  -- Assign a fresh wire holding @w * (1 - w)@. That product vanishes for
  -- @w = 0@ or @w = 1@ (and, over a field, only then). The result of
  -- 'isZero' on that wire becomes the output 'Bool'.
  isValid (Bool bc) = Bool (fromCircuitF bc $ \(Par1 w) -> do
    w' <- newAssigned (\ev -> ev w * (one - ev w))
    isZero (Par1 w'))


-- Raw symbolic payloads: this instance performs no check and reports every
-- value as valid.
instance
  ( Symbolic c
  , Representable f
  , Ord (Rep f)
  , Binary (Rep f)
  , NFData (Rep f)
  ) => SymbolicInput (c f) where
  isValid = \_ -> true


-- A 'Proxy' carries no data, so it is always reported as valid. The argument
-- is intentionally not pattern-matched, so it is never forced.
instance Symbolic c => SymbolicInput (Proxy c) where
  isValid = \_ -> true

-- A pair is valid when both components are valid; both components must share
-- one symbolic context.
instance
  ( SymbolicInput x
  , SymbolicInput y
  , Context x ~ Context y
  , Symbolic (Context x)
  ) => SymbolicInput (x, y) where
  isValid (lhs, rhs) =
    isValid lhs && isValid rhs

-- A triple is valid when all three components are valid; every component must
-- live in the same symbolic context. The '&&' chain is left unparenthesised
-- so it associates according to the operator's declared fixity.
instance
  ( SymbolicInput x
  , SymbolicInput y
  , SymbolicInput z
  , Context x ~ Context y
  , Context y ~ Context z
  , Symbolic (Context x)
  ) => SymbolicInput (x, y, z) where
  isValid (p, q, s) =
    isValid p
      && isValid q
      && isValid s

-- A vector is valid when every element is valid. The elements' verdicts are
-- combined left to right, starting from 'true', i.e.
-- @((true && isValid e1) && isValid e2) && ...@.
instance
  ( KnownNat n
  , SymbolicInput x
  , Symbolic (Context x)
  ) => SymbolicInput (Vector n x) where
  isValid vec = foldl (\acc item -> acc && isValid item) true items
    where
      items = fromVector vec