{-# LANGUAGE AllowAmbiguousTypes #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (
        mapVarArithmeticCircuit,
        mapVarWitness
    ) where

import           Data.Containers.ListUtils                           (nubOrd)
import           Data.List                                           (sort)
import           Data.Map                                            hiding (drop, foldl, foldr, fromList, map, null,
                                                                      splitAt, take, toList)
import           GHC.IsList                                          (IsList (..))
import           Numeric.Natural                                     (Natural)
import           Prelude                                             hiding (Num (..), drop, length, product, splitAt,
                                                                      sum, take, (!!), (^))

import           ZkFold.Base.Algebra.Basic.Class                     (MultiplicativeMonoid (..))
import           ZkFold.Base.Algebra.Polynomials.Multivariate
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (ArithmeticCircuit (..))

-- This module contains functions for mapping variables in arithmetic circuits.

-- | Re-key a witness map under a variable renaming.
--
-- Every key of the input map is replaced by @'mapVar' vars key@; the values
-- are left untouched.
--
-- NOTE(review): 'mapKeys' keeps only one entry when two old keys are sent to
-- the same new key, so this relies on @mapVar vars@ being injective on the
-- keys that are present — confirm against the definition of 'mapVar'.
mapVarWitness :: [Natural] -> (Map Natural a -> Map Natural a)
mapVarWitness vars = mapKeys (mapVar vars)

-- | Rename the variables of an arithmetic circuit.
--
-- The renaming is driven by @vars@: the sorted, duplicate-free list of every
-- variable occurring in the circuit's constraint polynomials, with variable
-- @0@ always included. That list is passed to 'mapVarPolynomial',
-- 'mapVarWitness' and 'mapVar' to rewrite the respective parts of the circuit:
--
-- * 'acSystem'  — each constraint is rewritten with 'mapVarPolynomial' and the
--   constraints are re-indexed with consecutive keys starting from @0@, in the
--   order returned by 'elems';
-- * 'acWitness' — the witness map produced by the original generator is
--   re-keyed with 'mapVarWitness';
-- * 'acOutput'  — the output variable is renamed with 'mapVar'.
--
-- All other fields of the circuit are carried over unchanged by the record
-- update.
mapVarArithmeticCircuit :: MultiplicativeMonoid a => ArithmeticCircuit a -> ArithmeticCircuit a
mapVarArithmeticCircuit ac =
    let -- Sorted list of distinct variables used by the constraints (plus 0).
        vars = nubOrd $ sort $ 0 : concatMap (toList . variables) (elems $ acSystem ac)
    in ac
    {
        acSystem  = fromList $ zip [0..] $ mapVarPolynomial vars <$> elems (acSystem ac),
        -- TODO: the new arithmetic circuit expects the old input variables! We should make this safer.
        acWitness = mapVarWitness vars . acWitness ac,
        acOutput  = mapVar vars $ acOutput ac
    }