module ZkFold.Base.Protocol.Protostar.Lookup where

import           Data.Map                                    (fromList, mapWithKey)
import           Data.These                                  (These (..))
import           Data.Zip
import           GHC.Generics
import           Prelude                                     hiding (Num (..), repeat, sum, zip, zipWith, (!!), (/),
                                                              (^))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field             (Zp)
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Sparse.Vector              (SVector (..))
import           ZkFold.Base.Data.Vector                     (Vector)
import           ZkFold.Base.Protocol.Protostar.SpecialSound (SpecialSoundProtocol (..), SpecialSoundTranscript)
import           ZkFold.Symbolic.MonadCircuit                (Arithmetic)

-- | Type-level tag selecting the Protostar lookup special-sound protocol.
--
-- It has no constructors. Its parameters are used by the
-- 'SpecialSoundProtocol' instance below:
--
--   * @l@ is the witness length (the witness is a @Vector l f@).
--   * @sizeT@ is the lookup-table size (table indices are @Zp sizeT@).
data ProtostarLookup (l :: Natural) (sizeT :: Natural)
    deriving (Generic)

-- | Public input of the Protostar lookup protocol: the table @t@ and its
-- inverse @t^{-1}@ (see the paper).
data ProtostarLookupParams f sizeT = ProtostarLookupParams
    (Zp sizeT -> f)    -- ^ @t@: table value stored at a given index
    (f -> [Zp sizeT])  -- ^ @t^{-1}@: table indices associated with a value
                       --   (presumably every @i@ with @t i == v@ -- confirm)
    deriving (Generic)

-- | Protostar lookup argument as a two-round special-sound protocol.
--
-- Round 1: the prover sends the witness @w@ and the sparse multiplicity
-- vector @m@.
--
-- Round 2: after receiving the challenge @r@, the prover sends
-- @h_i = 1 / (w_i + r)@ and @g_j = m_j / (t_j + r)@.
--
-- The verifier accepts iff all of the following hold:
--
--   * @sum h == sum g@
--   * @h_i * (w_i + r) == 1@ for every witness position @i@
--   * @g_j * (t_j + r) == m_j@ for every index @j@ present in @g@ or @m@
instance (Arithmetic f, KnownNat l, KnownNat sizeT) => SpecialSoundProtocol f (ProtostarLookup l sizeT) where
    type Witness f (ProtostarLookup l sizeT)         = Vector l f
    -- ^ w in the paper
    type Input f (ProtostarLookup l sizeT)           = ProtostarLookupParams f sizeT
    -- ^ t and t^{-1} from the paper
    type ProverMessage f (ProtostarLookup l sizeT)   = (Vector l f, SVector sizeT f)
    -- ^ (w, m) or (h, g) in the paper
    type VerifierMessage f (ProtostarLookup l sizeT) = f
    type VerifierOutput f (ProtostarLookup l sizeT)  = Bool

    type Degree (ProtostarLookup l sizeT)            = 2

    -- l + sizeT + 1
    outputLength :: ProtostarLookup l sizeT -> Natural
    outputLength _ = value @l + value @sizeT + 1

    rounds :: ProtostarLookup l sizeT -> Natural
    rounds _ = 2

    prover :: ProtostarLookup l sizeT
           -> Witness f (ProtostarLookup l sizeT)
           -> Input f (ProtostarLookup l sizeT)
           -> SpecialSoundTranscript f (ProtostarLookup l sizeT)
           -> ProverMessage f (ProtostarLookup l sizeT)
    -- Round 1 (empty transcript): send the witness together with its multiplicities.
    prover _ w (ProtostarLookupParams _ invT) [] =
        let -- Sparse vector with a 1 at every table index that invT associates with w_i.
            indicator :: f -> SVector sizeT f
            indicator w_i = SVector $ fromList [ (i, one) | i <- invT w_i ]

            -- m_j counts how many witness entries are mapped to table index j.
            m :: SVector sizeT f
            m = sum (indicator <$> w)
        in (w, m)
    -- Round 2: given challenge r, send h_i = 1 / (w_i + r) and g_j = m_j / (t_j + r).
    prover _ _ (ProtostarLookupParams t _) [((w, m), r)] =
        let h :: Vector l f
            h = fmap (\w_i -> one // (w_i + r)) w

            g :: SVector sizeT f
            g = SVector $ mapWithKey (\i m_i -> m_i // (t i + r)) $ fromSVector m
        in (h, g)
    prover _ _ _ _ = error "Invalid transcript"

    verifier :: ProtostarLookup l sizeT
             -> Input f (ProtostarLookup l sizeT)
             -> [ProverMessage f (ProtostarLookup l sizeT)]
             -> [f]
             -> Bool
    verifier _ (ProtostarLookupParams t _) [(w, m), (h, g)] [r, _] =
        let -- The two fractional sums agree: sum_i h_i == sum_j g_j.
            c1 :: Bool
            c1 = sum h == sum g

            -- Each h_i really is 1 / (w_i + r).
            c2 :: Bool
            c2 = all (== one) $ zipWith (*) h (fmap (+ r) w)

            -- g'_j = g_j * (t_j + r). This must be built from the round-2 message g,
            -- not from m: the check exists to tie g back to the multiplicities m.
            g' :: SVector sizeT f
            g' = SVector $ mapWithKey (\i g_i -> g_i * (t i + r)) $ fromSVector g

            -- Entries match only if the index is present in both vectors with equal values.
            sameEntry :: These f f -> Bool
            sameEntry = \case
                This _    -> False
                That _    -> False
                These x y -> x == y

            -- Each g_j really is m_j / (t_j + r), over every index in g or m.
            c3 :: Bool
            c3 = and $ alignWith sameEntry g' m
        in c1 && c2 && c3
    verifier _ _ _ _ = error "Invalid transcript"