{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}

module ZkFold.Symbolic.Compiler (
    module ZkFold.Symbolic.Compiler.ArithmeticCircuit,
    compile,
    compileIO,
    compileForceOne,
    solder,
) where

import           Data.Aeson                                 (FromJSON, ToJSON)
import           Data.Binary                                (Binary)
import           Data.Function                              (const, (.))
import           Data.Functor                               (($>))
import           Data.Functor.Rep                           (Rep, Representable)
import           Data.Ord                                   (Ord)
import           Data.Proxy                                 (Proxy (..))
import           Data.Traversable                           (for)
import           GHC.Generics                               (Par1 (Par1))
import           Prelude                                    (FilePath, IO, Monoid (mempty), Show (..), Traversable,
                                                             putStrLn, return, type (~), ($), (++))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Prelude                             (writeFileJSON)
import           ZkFold.Symbolic.Class                      (Arithmetic, Symbolic (..), fromCircuit2F)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit
import           ZkFold.Symbolic.Data.Bool                  (Bool (Bool))
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Input
import           ZkFold.Symbolic.MonadCircuit               (MonadCircuit (..))

{-
    ZkFold Symbolic compiler module dependency order:
    1. ZkFold.Symbolic.Compiler.ArithmeticCircuit.MerkleHash
    2. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
    3. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map
    4. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance
    5. ZkFold.Symbolic.Compiler.ArithmeticCircuit
    6. ZkFold.Symbolic.Compiler
-}

-- | Walks over every output variable of the given circuit, attaches the
-- constraint polynomial @x i - one@ to it, and returns the outputs unchanged.
--
-- Used by 'compileForceOne', whose documentation describes the result as a
-- circuit "with all outputs equal to 1".
-- NOTE(review): this presumes 'constraint' requires its polynomial to
-- vanish -- confirm against 'MonadCircuit'.
forceOne :: (Symbolic c, Traversable f) => c f -> c f
forceOne r = fromCircuitF r (\fi -> for fi $ \i -> constraint (\x -> x i - one) $> i)

-- | Arithmetizes an argument by feeding an appropriate amount of inputs.
--
-- The argument @f@ is applied (via 'pieces') to a synthetic @input@ of type
-- @Support f@. That input is 'restore'd from an empty circuit ('mempty')
-- whose 'acOutput' is set to 'acInput'.
--
-- The validity flag @isValid input@ is combined with the resulting layout
-- through 'fromCircuit2F': the constraint polynomial @one - x i@ is attached
-- to the flag's variable @i@ and the layout is returned unchanged.
-- NOTE(review): this presumes 'constraint' requires its polynomial to
-- vanish, i.e. the flag is forced to 'one' -- confirm against 'MonadCircuit'.
solder ::
    forall a c f s .
    ( c ~ ArithmeticCircuit a (Layout s)
    , SymbolicData f
    , Context f ~ c
    , Support f ~ s
    , SymbolicInput s
    , Context s ~ c
    , Symbolic c
    ) => f -> c (Layout f)
solder f = fromCircuit2F (pieces f input) b $ \r (Par1 i) -> do
    constraint (\x -> one - x i)
    return r
  where
    -- Unwrap the symbolic boolean to get at the underlying single-variable circuit.
    Bool b = isValid input
    -- The proxy/support argument is ignored ('const'); the circuit simply
    -- exposes its own inputs as outputs.
    input = restore @(Support f) $ const mempty { acOutput = acInput }

-- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1.
--
-- Pipeline, read right to left: 'solder' arithmetizes @f@, 'forceOne'
-- attaches the per-output constraints, 'optimize' post-processes the circuit,
-- and @restore . const@ rebuilds a value of type @y@ from the circuit,
-- ignoring the @Proxy c@ support argument.
compileForceOne ::
    forall a c f s l y .
    ( c ~ ArithmeticCircuit a l
    , Arithmetic a
    , Binary a
    , SymbolicData f
    , Context f ~ c
    , Support f ~ s
    , SymbolicInput s
    , Context s ~ c
    , Layout s ~ l
    , Representable l
    , Binary (Rep l)
    , Ord (Rep l)
    , SymbolicData y
    , Context y ~ c
    , Support y ~ Proxy c
    , Layout f ~ Layout y
    , Traversable (Layout y)
    ) => f -> y
compileForceOne = restore . const . optimize . forceOne . solder @a

-- | Compiles function `f` into an arithmetic circuit.
--
-- Pipeline, read right to left: 'solder' arithmetizes @f@, 'optimize'
-- post-processes the circuit, and @restore . const@ rebuilds a value of type
-- @y@ from the circuit, ignoring the @Proxy c@ support argument.
compile ::
    forall a c f s l y .
    ( c ~ ArithmeticCircuit a l
    , SymbolicData f
    , Context f ~ c
    , Support f ~ s
    , SymbolicInput s
    , Context s ~ c
    , Layout s ~ l
    , SymbolicData y
    , Context y ~ c
    , Support y ~ Proxy c
    , Layout f ~ Layout y
    , Symbolic c
    ) => f -> y
compile = restore . const . optimize . solder @a

-- | Compiles a function `f` into an arithmetic circuit. Writes the result to a file.
--
-- The circuit is obtained as @optimize (solder f)@. Progress is reported on
-- standard output: a banner, the value of 'acSizeN' (labelled "Number of
-- constraints"), the value of 'acSizeM' (labelled "Number of variables"), and
-- finally the path the circuit was written to via 'writeFileJSON'.
compileIO ::
    forall a c f s l .
    ( c ~ ArithmeticCircuit a l
    , FromJSON a
    , ToJSON a
    , SymbolicData f
    , Context f ~ c
    , Support f ~ s
    , ToJSON (Layout f (Var a l))
    , SymbolicInput s
    , Context s ~ c
    , Layout s ~ l
    , FromJSON (Rep l)
    , ToJSON (Rep l)
    , Symbolic c
    ) => FilePath -> f -> IO ()
compileIO scriptFile f = do
    -- The annotation pins the circuit type so that @solder \@a@ is unambiguous.
    let ac = optimize (solder @a f) :: c (Layout f)

    putStrLn "\nCompiling the script...\n"

    putStrLn $ "Number of constraints: " ++ show (acSizeN ac)
    putStrLn $ "Number of variables: " ++ show (acSizeM ac)
    writeFileJSON scriptFile ac
    putStrLn $ "Script saved: " ++ scriptFile