{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Conditional where

import           Control.Monad.Representable.Reader (Representable, mzipWithRep)
import           Data.Function                      (($))
import           Data.Traversable                   (Traversable)
import           Data.Type.Equality                 (type (~))
import           GHC.Generics                       (Par1 (Par1))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool          (Bool (Bool), BoolType)
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Combinators   (mzipWithMRep)
import           ZkFold.Symbolic.MonadCircuit       (newAssigned)

-- | Values of type @a@ that can be chosen between by a condition of the
-- boolean-like type @b@.
class BoolType b => Conditional b a where
    -- | @bool onFalse onTrue cond@ yields one of the two given values,
    -- depending on @cond@.
    --
    -- Properties:
    --
    -- [On true] @bool onFalse onTrue 'true' == onTrue@
    --
    -- [On false] @bool onFalse onTrue 'false' == onFalse@
    bool
      :: a  -- ^ result when the condition is false
      -> a  -- ^ result when the condition is true
      -> b  -- ^ the condition
      -> a

-- | 'bool' with the condition first and the branches in
-- @if@/@then@/@else@ order: @gif cond onTrue onFalse@.
--
-- Equivalent to @bool onFalse onTrue cond@.
gif :: Conditional b a => b -> a -> a -> a
gif b x y = bool y x b

-- | Operator synonym for 'gif': @(cond ? onTrue) onFalse@ is
-- @gif cond onTrue onFalse@.
(?) :: Conditional b a => b -> a -> a -> a
(?) = gif

-- | Selection between two values of any symbolic data type @x@, driven by a
-- symbolic 'Bool' living in the same context @c@.
--
-- Both parts of the result compute @(1 - b) * onFalse + b * onTrue@
-- element-wise:
--
-- * the circuit (layout) part, built with 'fromCircuit3F' from the
--   arithmetized forms of both arguments;
--
-- * the witness (payload) part, evaluated directly on the payload values.
--
-- NOTE(review): this formula is a selection only when the boolean's value is
-- 0 or 1; presumably that is guaranteed wherever 'Bool' values are built
-- (not visible here).
instance ( SymbolicData x, Context x ~ c, Symbolic c
         , Representable (Layout x), Traversable (Layout x)
         , Representable (Payload x)
         ) => Conditional (Bool c) x where
    bool x y (Bool b) = restore $ \s ->
      ( -- Circuit part: combine the condition with both layouts.
        fromCircuit3F b (arithmetize x s) (arithmetize y s) $ \(Par1 c) ->
          mzipWithMRep $ \i j -> do
            -- Split into three assignments so that each one contains at most
            -- a single multiplication (presumably to keep every constraint
            -- low-degree).
            i' <- newAssigned (\w -> (one - w c) * w i)
            j' <- newAssigned (\w -> w c * w j)
            newAssigned (\w -> w i' + w j')
        -- Witness part: the same formula, evaluated on concrete payload values.
      , let Par1 wb = witnessF b
         in mzipWithRep
              (\wx wy -> (one - wb) * wx + wb * wy)
              (payload x s)
              (payload y s)
      )