{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (
        mapVarArithmeticCircuit,
        ArithmeticCircuitTest(..)
    ) where

import           Data.Functor.Rep                                    (Representable (..))
import           Data.Map                                            hiding (drop, foldl, foldr, fromList, map, null,
                                                                      splitAt, take, toList)
import qualified Data.Map                                            as Map
import           GHC.IsList                                          (IsList (..))
import           Numeric.Natural                                     (Natural)
import           Prelude                                             hiding (Num (..), drop, length, product, splitAt,
                                                                      sum, take, (!!), (^))
import           Test.QuickCheck                                     (Arbitrary (arbitrary), Gen)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Polynomials.Multivariate
import           ZkFold.Base.Data.ByteString                         (toByteString)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), SysVar (..),
                                                                      Var (..), VarField, getAllVars)

-- This module contains functions for mapping variables in arithmetic circuits.

-- | A test case bundling an arithmetic circuit with a value for its inputs.
data ArithmeticCircuitTest a i o = ArithmeticCircuitTest
    {
        -- | The circuit under test.
        arithmeticCircuit :: ArithmeticCircuit a i o
        -- | An input container of values of type @a@, stored next to the
        -- circuit. NOTE(review): presumably this is the assignment fed to the
        -- circuit's witness generators -- confirm against the callers.
        , witnessInput    :: i a
    }

-- | Shows the circuit via its own 'Show' instance, then a comma, a newline,
-- the label @witnessInput: @ and finally the shown witness input.
instance (Show (ArithmeticCircuit a i o), Show a, Show (i a)) => Show (ArithmeticCircuitTest a i o) where
    show (ArithmeticCircuitTest ac wi) = show ac ++ ",\nwitnessInput: " ++ show wi

-- | Generates a test case by drawing the circuit and the witness input
-- independently, each from its own 'Arbitrary' instance (circuit first).
instance (Arithmetic a, Arbitrary (i a), Arbitrary (ArithmeticCircuit a i f), Representable i) => Arbitrary (ArithmeticCircuitTest a i f) where
    arbitrary :: Gen (ArithmeticCircuitTest a i f)
    arbitrary = do
        ac <- arbitrary
        wi <- arbitrary
        return ArithmeticCircuitTest {
            arithmeticCircuit = ac
            , witnessInput = wi
            }

-- | Renames the 'NewVar' variables of a circuit to a canonical sequence of names.
--
-- The 'NewVar' names reported by 'getAllVars' are paired, in order, with the
-- 'toByteString' encodings of the 'VarField' elements 0, 1, 2, ... and every
-- occurrence is replaced accordingly:
--
-- * constraint polynomials have their variables substituted, and the
--   constraints themselves are re-keyed with the same canonical sequence
--   (in 'elems' order of the original 'acSystem');
-- * witness generators are re-keyed by the new names; each generator is
--   wrapped so that the map it receives (keyed by new names) is translated
--   back to the old names before the original generator runs;
-- * 'SysVar' outputs are renamed, 'ConstVar' outputs are kept as they are.
--
-- 'InVar' variables and the 'witnessInput' are passed through unchanged.
mapVarArithmeticCircuit :: (Field a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i o -> ArithmeticCircuitTest a i o
mapVarArithmeticCircuit (ArithmeticCircuitTest ac wi) =
    let -- Old names of all 'NewVar's, in the order 'getAllVars' reports them.
        vars = [v | NewVar v <- getAllVars ac]
        -- Canonical replacement names. The list is infinite and is only ever
        -- consumed through 'zip', which truncates it.
        asc = [ toByteString @VarField (fromConstant @Natural x) | x <- [0..] ]
        -- 'Map.fromList' rather than 'Map.fromAscList': the latter silently
        -- builds a malformed map when its keys are not ascending, and nothing
        -- here guarantees that ordering for the encoded names.
        forward = Map.fromList $ zip vars asc
        backward = Map.fromList $ zip asc vars
        varF (InVar v)  = InVar v
        -- '!' calls 'error' on a missing key.
        -- NOTE(review): assumes every 'NewVar' occurring in the constraints
        -- and outputs is also reported by 'getAllVars' -- confirm.
        varF (NewVar v) = NewVar (forward ! v)
        mappedCircuit = ac
            {
                acSystem  = fromList $ zip asc $ evalPolynomial evalMonomial (var . varF) <$> elems (acSystem ac),
                -- TODO: the new arithmetic circuit expects the old input variables! We should make this safer.
                acWitness = (`Map.compose` backward) $ (\f i m -> f i (Map.compose m forward)) <$> acWitness ac
            }
        varG = \case
          SysVar v   -> SysVar (varF v)
          ConstVar c -> ConstVar c
        mappedOutputs = varG <$> acOutput ac
    in ArithmeticCircuitTest (mappedCircuit {acOutput = mappedOutputs}) wi