{-# LANGUAGE DerivingVia   #-}
{-# LANGUAGE TypeOperators #-}

module ZkFold.Symbolic.MonadCircuit where

import           Control.Applicative             (Applicative)
import           Control.Monad                   (Monad (return))
import           Data.Eq                         (Eq)
import           Data.Function                   (id)
import           Data.Functor                    (Functor)
import           Data.Functor.Identity           (Identity (..))
import           Data.Ord                        (Ord)
import           Data.Type.Equality              (type (~))
import           Numeric.Natural                 (Natural)

import           ZkFold.Base.Algebra.Basic.Class

-- | The bundle of constraints required of a field of witnesses @a@ whose
-- constants have type @n@: it should support all algebraic operations
-- used inside an arithmetic circuit.
type WitnessField n a =
  ( FiniteField a
  , ToConstant a
  , Const a ~ n
  , FromConstant n a
  , SemiEuclidean n
  )

-- | Witness builders over variables of type @var@ and a base field @a@.
--
-- A witness builder is a function which works for /every/ field of witnesses
-- @x@ over @a@: given a mapping from known variables to their witnesses
-- in @x@, it computes the new witness, also in @x@.
--
-- NOTE: any function which typechecks against this alias has the property
-- above by construction, so there is nothing to verify by hand.
type Witness var a =
  forall x n . (Algebra a x, WitnessField n x) => (var -> x) -> x

-- | Polynomial expressions over variables of type @var@ and a base field @a@.
--
-- A polynomial expression is a function which works for /every/ algebra @x@
-- over @a@: given a mapping from known variables to their values in @x@,
-- it computes a new value in that same algebra.
--
-- NOTE: any function which typechecks against this alias has the property
-- above by construction, so there is nothing to verify by hand.
type ClosedPoly var a =
  forall x . Algebra a x => (var -> x) -> x

-- | Constraints on a freshly introduced variable, over variables of type
-- @var@ and a base field @a@.
--
-- Such a constraint is a function which works for /every/ algebra @x@ over
-- @a@: given a mapping from known variables to their values in @x@ together
-- with the new variable itself, it computes the value of the constraint
-- polynomial in that algebra.
--
-- NOTE: any function which typechecks against this alias has the property
-- above by construction, so there is nothing to verify by hand.
type NewConstraint var a =
  forall x . Algebra a x => (var -> x) -> var -> x

-- | A monadic DSL for constructing arithmetic circuits.
-- @var@ is a type of variables, @a@ is a base field
-- and @m@ is a monad for constructing the circuit.
--
-- DSL provides the following guarantees:
--
-- * Constraints never reference undefined variables;
-- * Variables with equal witnesses are reused as much as possible;
-- * Variables with different witnesses are different;
-- * There is an order in which witnesses can be generated.
--
-- However, DSL does NOT provide the following guarantees (yet):
--
-- * That provided witnesses satisfy the provided constraints. To check this,
--   you can use 'ZkFold.Symbolic.Compiler.ArithmeticCircuit.checkCircuit'.
-- * That introduced constraints are supported by the zk-SNARK utilized for later proving.
class (Monad m, FromConstant a var) => MonadCircuit var a m | m -> var, m -> a where
  -- | Creates new variable from witness.
  --
  -- NOTE: this does not add any constraints to the system,
  -- use 'rangeConstraint' or 'constraint' to add them.
  unconstrained :: Witness var a -> m var

  -- | Adds new polynomial constraint to the system.
  -- E.g., @'constraint' (\\x -> x i)@ forces variable @i@ to be zero.
  --
  -- NOTE: it is not checked (yet) whether provided constraint is in
  -- appropriate form for zkSNARK in use.
  constraint :: ClosedPoly var a -> m ()

  -- | Adds new range constraint to the system.
  -- E.g., @'rangeConstraint' var B@ forces variable @var@ to be in range \([0; B]\).
  rangeConstraint :: var -> a -> m ()

  -- | Creates new variable given a polynomial witness
  -- AND adds a corresponding polynomial constraint.
  --
  -- E.g., @'newAssigned' (\\x -> x i + x j)@ creates new variable @k@
  -- whose value is equal to \(x_i + x_j\)
  -- and a constraint \(x_i + x_j - x_k = 0\).
  --
  -- NOTE: this adds a polynomial constraint to the system.
  --
  -- NOTE: it is not checked (yet) whether the corresponding constraint is in
  -- appropriate form for zkSNARK in use.
  newAssigned :: ClosedPoly var a -> m var
  -- Default: the polynomial @p@ is used twice, once as the witness of the
  -- new variable and once inside the constraint @p x - x var@ which ties
  -- the new variable @var@ to the value of @p@.
  newAssigned p = newConstrained (\x var -> p x - x var) p

-- | Creates new variable from witness constrained with an inclusive upper bound.
-- E.g., @'newRanged' b (\\x -> x var - one)@ creates new variable whose value
-- is equal to @x var - one@ and which is expected to be in range @[0..b]@.
--
-- The variable is first introduced with 'unconstrained' and then restricted
-- with 'rangeConstraint'.
--
-- NOTE: this adds a range constraint to the system.
newRanged :: MonadCircuit var a m => a -> Witness var a -> m var
newRanged upperBound witness = do
  v <- unconstrained witness
  rangeConstraint v upperBound
  return v

-- | Creates new variable from witness constrained by a polynomial.
-- E.g., @'newConstrained' (\\x i -> x i * (x i - one)) (\\x -> x j - one)@
-- creates new variable whose value is equal to @x j - one@ and which is
-- expected to be a root of the polynomial @x i * (x i - one)@.
--
-- The variable is first introduced with 'unconstrained'; the supplied
-- polynomial, applied to that new variable, is then registered with
-- 'constraint'.
--
-- NOTE: this adds a polynomial constraint to the system.
--
-- NOTE: it is not checked (yet) whether provided constraint is in
-- appropriate form for zkSNARK in use.
newConstrained :: MonadCircuit var a m => NewConstraint var a -> Witness var a -> m var
newConstrained poly witness = do
  v <- unconstrained witness
  -- Fix the "new variable" argument of @poly@ to the freshly created @v@,
  -- which turns the 'NewConstraint' into a 'ClosedPoly'.
  constraint (\x -> poly x v)
  return v

-- | A field of witnesses (with 'Natural' constants) that additionally has
-- decidable equality and ordering is called an /arithmetic/ field.
type Arithmetic a =
  ( WitnessField Natural a
  , Eq a
  , Ord a
  )

-- | An example implementation of a @'MonadCircuit'@ which computes witnesses
-- immediately and drops the constraints.
--
-- The wrapper only stores a value of type @x@; the parameters @n@ and @a@ are
-- phantom. Its 'Functor', 'Applicative' and 'Monad' instances are borrowed
-- from 'Identity', so running a computation is just unwrapping it with
-- 'runWitnesses'.
newtype Witnesses n a x = Witnesses { runWitnesses :: x }
  deriving (Functor, Applicative, Monad) via Identity

-- | Evaluates circuits directly on witnesses: variables /are/ their witness
-- values (@var ~ a@), and no constraint is ever recorded.
instance WitnessField n a => MonadCircuit a a (Witnesses n a) where
  -- A variable is its own witness, so the lookup function is 'id'.
  unconstrained w = return (w id)
  -- Polynomial constraints are dropped.
  constraint _ = return ()
  -- Range constraints are dropped.
  rangeConstraint _ _ = return ()