{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (
        mapVarArithmeticCircuit,
    ) where

import           Data.Bifunctor                                      (bimap)
import           Data.Functor.Rep                                    (Representable (..))
import           Data.Map                                            hiding (drop, foldl, foldr, fromList, map, null,
                                                                      splitAt, take, toList)
import qualified Data.Map                                            as Map
import qualified Data.Set                                            as Set
import           GHC.IsList                                          (IsList (..))
import           Numeric.Natural                                     (Natural)
import           Prelude                                             hiding (Num (..), drop, length, product, splitAt,
                                                                      sum, take, (!!), (^))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Polynomials.Multivariate
import           ZkFold.Base.Data.ByteString                         (toByteString)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (ArithmeticCircuit (..), SysVar (..), Var (..),
                                                                      VarField, WitVar (..), getAllVars)

-- This module contains functions for mapping variables in arithmetic circuits.

-- | Canonicalise the names of all internal ('NewVar') variables of a circuit.
--
-- Every 'NewVar' reported by 'getAllVars' is renamed to the byte encoding of
-- a consecutive natural number (@0, 1, 2, ...@, each encoded as a 'VarField'
-- element), in the order 'getAllVars' returns them. Input variables
-- ('InVar') and external witness variables ('WExVar') are left untouched.
--
-- The renaming is applied consistently to:
--
-- * the range constraints ('acRange'),
-- * the polynomial constraints ('acSystem'), which are also re-keyed with the
--   same numeric encodings, in ascending order of their current keys,
-- * the witness map ('acWitness'), whose keys become the new variable names.
--   A variable with no entry in the original witness map gets no entry in
--   the result ('Map.compose' drops missing lookups),
-- * the folds ('acFold') and the outputs ('acOutput').
--
-- The renaming tables are built with 'Map.fromList' rather than
-- 'Map.fromAscList'. The latter silently produces a corrupt map unless its
-- keys are strictly ascending, and that is not guaranteed here. Neither the
-- order of 'getAllVars' nor the byte order of the encodings of consecutive
-- naturals is known to be sorted.
--
-- NOTE(review): the lookup @forward ! v@ assumes every 'NewVar' occurring in
-- the circuit is reported by 'getAllVars'. If one is not, this function
-- calls 'error'. TODO confirm against 'getAllVars'.
mapVarArithmeticCircuit ::
  (Field a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) =>
  ArithmeticCircuit a p i o -> ArithmeticCircuit a p i o
mapVarArithmeticCircuit ac =
    let -- Current names of the internal variables, in 'getAllVars' order.
        vars = [v | NewVar v <- getAllVars ac]
        -- Canonical names: encodings of 0, 1, 2, ... as 'VarField' elements.
        asc = [ toByteString @VarField (fromConstant @Natural x) | x <- [0..] ]
        -- Old name -> canonical name. 'vars' is not known to be sorted,
        -- so 'fromAscList' would be unsafe here.
        forward = Map.fromList $ zip vars asc
        -- Canonical name -> old name. The byte encodings in 'asc' are not
        -- known to sort in numeric order, so 'fromAscList' would be unsafe.
        backward = Map.fromList $ zip asc vars
        -- Rename a system variable; inputs keep their names.
        varF (InVar v)  = InVar v
        varF (NewVar v) = NewVar (forward ! v)
        -- Rename the variable inside a linear output term.
        oVarF (LinVar k v b) = LinVar k (varF v) b
        oVarF (ConstVar c)   = ConstVar c
        -- Rename system variables referenced by witnesses; external
        -- witness variables keep their names.
        witF (WSysVar v) = WSysVar (varF v)
        witF (WExVar v)  = WExVar v
     in ArithmeticCircuit
          { acRange   = Set.map varF <$> acRange ac
          , acSystem  = fromList $ zip asc $
              evalPolynomial evalMonomial (var . varF) <$> elems (acSystem ac)
            -- Re-key witnesses by canonical name: for each (new, old) pair in
            -- 'backward', look up the renamed witness stored under 'old'.
          , acWitness = (fmap witF <$> acWitness ac) `Map.compose` backward
          , acFold    = bimap oVarF (fmap witF) <$> acFold ac
          , acOutput  = oVarF <$> acOutput ac
          }