{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE NoOverloadedStrings  #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.Protostar.Oracle where

import           Data.Char                                      (ord)
import           Data.Foldable                                  (foldl')
import           Data.Proxy                                     (Proxy (..))
import qualified Data.Vector                                    as V
import           GHC.Generics
import           GHC.TypeLits
import           Prelude                                        (foldl, ($), (.), (<$>))
import qualified Prelude                                        as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Symbolic.Algorithms.Hash.MiMC           (mimcHash2)
import           ZkFold.Symbolic.Algorithms.Hash.MiMC.Constants (mimcConstants)

-- | Maps a value of type @a@ to a value of type @b@.
--
-- An instance that does not define 'oracle' falls back to the generic
-- default. That default converts the value to its "GHC.Generics"
-- representation with 'from' and hashes the representation with 'oracle''.
class RandomOracle a b where
    oracle :: a -> b
    default oracle :: (Generic a, RandomOracle' (Rep a) b) => a -> b
    oracle value = oracle' (from value)

-- | Integers are first embedded into @a@ with 'fromConstant'.
-- The embedded value is then passed to 'oracle' at type @a -> a@.
instance (Ring a, FromConstant P.Integer a) => RandomOracle P.Integer a where
    oracle n = oracle @a (fromConstant n)

-- | Hashes a ring element to the same type with a single 'mimcHash2' call.
--
-- The call passes 'mimcConstants' as the round constants and the input as
-- the second argument. The two remaining arguments are 'zero'.
-- NOTE(review): judging by 'mimcHash2''s @[a] -> a -> x -> x -> x@ type, the
-- input occupies the key-like slot; confirm against the MiMC module.
instance Ring a => RandomOracle a a where
    oracle input = mimcHash2 mimcConstants input zero zero

-- | Absorbs a list element by element, from left to right.
--
-- Each element @x@ is mapped to @y = oracle x@. It is then combined with the
-- running accumulator as @y * acc + y@, starting from 'zero'. So the empty
-- list maps to 'zero' and a singleton @[x]@ maps to @oracle x@.
--
-- Caveat: in a ring @y * acc + y = y * (acc + 1)@. An element whose image is
-- 'zero' therefore erases everything absorbed before it. An accumulator equal
-- to @-1@ makes the next step yield 'zero' regardless of the element.
--
-- A strict left fold ('foldl'') keeps the accumulator evaluated at each step.
-- The lazy 'foldl' would instead build a chain of thunks as long as the list.
instance (Ring b, RandomOracle a b) => RandomOracle [a] b where
    oracle = foldl' (\acc x -> let y = oracle x in y * acc + y) zero

-- | Absorbs a vector element by element, from left to right.
--
-- This uses the same combining step as the list instance. Each element @x@ is
-- mapped to @y = oracle x@ and folded in as @y * acc + y@, starting from
-- 'zero'. The empty vector maps to 'zero'.
--
-- Caveat: @y * acc + y = y * (acc + 1)@, so an element hashing to 'zero'
-- erases all earlier elements.
--
-- 'V.foldl'' forces the accumulator at each step. The lazy Foldable 'foldl'
-- on a boxed vector would accumulate a chain of unevaluated thunks.
instance (Ring b, RandomOracle a b) => RandomOracle (V.Vector a) b where
    oracle = V.foldl' (\acc x -> let y = oracle x in y * acc + y) zero

-- | Fallback for any type with a 'Generic' instance.
--
-- The value is converted with 'from' and its representation is hashed with
-- 'oracle''. This is the same definition as the class default. The
-- OVERLAPPABLE pragma lets more specific 'RandomOracle' instances take
-- precedence.
instance {-# OVERLAPPABLE #-} (Generic a, RandomOracle' (Rep a) b) => RandomOracle a b where
    oracle x = oracle' (from x)

-- | Helper class behind the generic default of 'oracle'.
--
-- It hashes a "GHC.Generics" representation @f p@ into a value of type @b@.
class RandomOracle' f b where
    oracle' :: f p -> b

-- | Sum of alternatives: hash whichever side is present.
--
-- NOTE(review): the 'L1' / 'R1' tag itself is not mixed into the result.
-- The generic 'M1' instance also ignores metadata for constructors that have
-- fields. Two different constructors whose fields hash equally therefore
-- appear to yield equal outputs; only field-less constructors are separated,
-- by their names. Confirm whether callers rely on this before changing it.
instance (RandomOracle' f b, RandomOracle' g b) => RandomOracle' (f :+: g) b where
    oracle' s = case s of
        L1 l -> oracle' l
        R1 r -> oracle' r

-- | Product of two components: @hl * hr + hl@, where @hl@ and @hr@ are the
-- oracle values of the left and right components.
--
-- TODO: it is not secure if we know the preimage of 0 or (-1).
-- In a ring @hl * hr + hl = hl * (hr + 1)@. The result is therefore 'zero'
-- whenever @hl@ is 'zero' or @hr@ is @-1@, independently of the other
-- component.
instance (RandomOracle' f b, RandomOracle' g b, Ring b) => RandomOracle' (f :*: g) b where
    oracle' (lhs :*: rhs) = hl * hr + hl
      where
        hl = oracle' lhs
        hr = oracle' rhs

-- | A single field: unwrap it and delegate to the field type's own
-- 'RandomOracle' instance.
instance RandomOracle c b => RandomOracle' (K1 i c) b where
    oracle' = oracle . unK1

-- | Handling constructors with no fields.
--
-- There is no payload to hash, so the oracle is computed from the
-- constructor's name:
--
-- * each character is converted to its code point with 'ord';
-- * each code point is lifted into @a@ through 'fromConstant' at type
--   'Natural';
-- * the resulting list is absorbed with 'oracle' at type @[a] -> a@.
--
instance {-# OVERLAPPING #-}
    ( KnownSymbol conName
    , Ring a
    , FromConstant Natural a
    ) => RandomOracle' (M1 C ('MetaCons conName fixity selectors) U1) a where
    oracle' _ =
        oracle @[a]
            [ fromConstant @Natural (P.fromIntegral (ord ch))
            | ch <- symbolVal (Proxy @conName)
            ]

-- | Metadata wrappers (datatype, constructor, selector).
--
-- The wrapper is stripped with 'unM1' and the contents are hashed; the
-- metadata itself does not affect the result.
instance RandomOracle' f b => RandomOracle' (M1 c m f) b where
    oracle' = oracle' . unM1