{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}

module ZkFold.Symbolic.Compiler (
    module ZkFold.Symbolic.Compiler.ArithmeticCircuit,
    compile,
    compileIO,
    compileForceOne,
    solder,
) where

import           Data.Aeson                                 (FromJSON, ToJSON, ToJSONKey)
import           Data.Binary                                (Binary)
import           Data.Function                              (const, (.))
import           Data.Functor                               (($>))
import           Data.Functor.Rep                           (Rep)
import           Data.Ord                                   (Ord)
import           Data.Proxy                                 (Proxy (..))
import           Data.Traversable                           (for)
import           GHC.Generics                               (Par1 (Par1))
import           Prelude                                    (FilePath, IO, Show (..), Traversable, putStrLn, return,
                                                             type (~), ($), (++))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Prelude                             (writeFileJSON)
import           ZkFold.Symbolic.Class                      (Arithmetic, Symbolic (..), fromCircuit2F)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit
import           ZkFold.Symbolic.Data.Bool                  (Bool (Bool))
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Input
import           ZkFold.Symbolic.MonadCircuit               (MonadCircuit (..))

{-
    ZkFold Symbolic compiler module dependency order:
    1. ZkFold.Symbolic.Compiler.ArithmeticCircuit.MerkleHash
    2. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
    3. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map
    4. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance
    5. ZkFold.Symbolic.Compiler.ArithmeticCircuit
    6. ZkFold.Symbolic.Compiler
-}

-- | Constrains every output variable of a circuit to be equal to 'one'.
--
-- For each output @v@ a constraint on the closed polynomial @v - 1@ is
-- emitted via 'constraint'. The outputs themselves are passed through
-- unchanged, so the resulting circuit keeps the layout @f@ of its argument.
forceOne :: (Symbolic c, Traversable f) => c f -> c f
forceOne circuit =
    fromCircuitF circuit $ \outputs ->
        for outputs $ \var -> do
            -- Require @var - 1@ to vanish, forcing this output to 'one'.
            constraint (\eval -> eval var - one)
            return var

-- | Arithmetizes an argument by feeding an appropriate amount of inputs.
--
-- A symbolic input of type @'Support' f@ is rebuilt with 'restore' from
-- 'idCircuit' (the support argument of 'restore' is ignored via 'const').
-- The argument is evaluated on that input with 'pieces'. The validity bit
-- produced by 'isValid' for the same input is then constrained to equal
-- 'one'. The outputs of 'pieces' are returned unchanged.
solder ::
    forall a c p f s .
    ( c ~ ArithmeticCircuit a p (Layout s)
    , SymbolicData f
    , Context f ~ c
    , Support f ~ s
    , SymbolicInput s
    , Context s ~ c
    , Symbolic c
    ) => f -> c (Layout f)
solder f =
    let -- Symbolic input built from the identity circuit.
        inputs = restore @(Support f) (const idCircuit)
        -- Validity bit of that input, as a one-wire circuit.
        Bool validity = isValid inputs
     in fromCircuit2F (pieces f inputs) validity $ \outs (Par1 flag) -> do
            -- Require @1 - flag@ to vanish, i.e. the input must be valid.
            constraint (\eval -> one - eval flag)
            return outs

-- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1.
--
-- The argument is first turned into a circuit with 'solder'. Every output
-- is then constrained to 'one' with 'forceOne', and the circuit is passed
-- through 'optimize'. Finally the circuit is repackaged with 'restore' as
-- the requested symbolic value @y@; the 'Proxy' support that 'restore'
-- supplies is ignored.
compileForceOne ::
    forall a c p f s l y .
    ( c ~ ArithmeticCircuit a p l
    , Arithmetic a
    , Binary a
    , SymbolicData f
    , Context f ~ c
    , Support f ~ s
    , SymbolicInput s
    , Context s ~ c
    , Layout s ~ l
    , Binary (Rep p)
    , Binary (Rep l)
    , Ord (Rep l)
    , SymbolicData y
    , Context y ~ c
    , Support y ~ Proxy c
    , Layout f ~ Layout y
    , Traversable (Layout y)
    ) => f -> y
compileForceOne fn = restore (const forced)
  where
    -- Soldered, output-forced and optimized circuit for @fn@.
    forced = optimize (forceOne (solder @a fn))

-- | Compiles function `f` into an arithmetic circuit.
--
-- The argument is turned into a circuit with 'solder' and passed through
-- 'optimize'. The result is then repackaged with 'restore' as the requested
-- symbolic value @y@; the 'Proxy' support that 'restore' supplies is
-- ignored.
compile ::
    forall a c p f s l y .
    ( c ~ ArithmeticCircuit a p l
    , SymbolicData f
    , Context f ~ c
    , Support f ~ s
    , SymbolicInput s
    , Context s ~ c
    , Layout s ~ l
    , SymbolicData y
    , Context y ~ c
    , Support y ~ Proxy c
    , Layout f ~ Layout y
    , Symbolic c
    ) => f -> y
compile fn = restore (const optimized)
  where
    -- Soldered and optimized circuit for @fn@.
    optimized = optimize (solder @a fn)

-- | Compiles a function `f` into an arithmetic circuit. Writes the result to a file.
--
-- The circuit is produced by 'solder' followed by 'optimize'. Its
-- constraint count ('acSizeN') and variable count ('acSizeM') are printed
-- to stdout. The circuit is then serialized to @scriptFile@ with
-- 'writeFileJSON', and the target path is reported.
compileIO ::
    forall a c p f s l .
    ( c ~ ArithmeticCircuit a p l
    , FromJSON a
    , ToJSON a
    , ToJSONKey a
    , SymbolicData f
    , Context f ~ c
    , Support f ~ s
    , ToJSON (Layout f (Var a l))
    , SymbolicInput s
    , Context s ~ c
    , Layout s ~ l
    , FromJSON (Rep l)
    , ToJSON (Rep l)
    , Symbolic c
    ) => FilePath -> f -> IO ()
compileIO scriptFile f = do
    putStrLn "\nCompiling the script...\n"

    putStrLn ("Number of constraints: " ++ show (acSizeN circuit))
    putStrLn ("Number of variables: " ++ show (acSizeM circuit))
    writeFileJSON scriptFile circuit
    putStrLn ("Script saved: " ++ scriptFile)
  where
    -- The compiled circuit. Evaluation is lazy, so the work happens when
    -- the statistics above first force it.
    circuit :: c (Layout f)
    circuit = optimize (solder @a f)