{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Conditional where

import qualified Data.Bool                        as H
import           Data.Function                    (($))
import           Data.Functor.Rep                 (Representable, mzipWithRep)
import           Data.Proxy
import           Data.Traversable                 (Traversable)
import           GHC.Generics
import qualified Prelude

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Vector
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool        (Bool (Bool), BoolType)
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Combinators (mzipWithMRep)
import           ZkFold.Symbolic.MonadCircuit     (newAssigned)

-- | Values of type @a@ that can be chosen between by a condition of type @b@.
--
-- Instances that give no explicit 'bool' fall back to the 'Generic'-based
-- default, which selects field by field through 'GConditional'.
class BoolType b => Conditional b a where
    -- | @bool onFalse onTrue cond@ yields @onTrue@ when @cond@ holds and
    -- @onFalse@ otherwise. Note that the false branch comes first.
    --
    -- Properties:
    --
    -- [On true] @bool onFalse onTrue 'true' == onTrue@
    --
    -- [On false] @bool onFalse onTrue 'false' == onFalse@
    bool :: a -> a -> b -> a
    default bool :: (Generic a, GConditional b (Rep a)) => a -> a -> b -> a
    bool onFalse onTrue cond =
      -- Go through the generic representation and select on it.
      to $ gbool (from onFalse) (from onTrue) cond

-- | 'bool' with the condition first and the true branch before the false
-- branch: @ifThenElse cond onTrue onFalse == bool onFalse onTrue cond@.
--
-- Its shape matches what GHC's @RebindableSyntax@ uses to desugar
-- @if@/@then@/@else@; whether callers rely on that is not visible here.
ifThenElse :: Conditional b a => b -> a -> a -> a
ifThenElse cond onTrue onFalse = bool onFalse onTrue cond

-- | Operator alias for 'ifThenElse', e.g. @cond ? onTrue $ onFalse@.
--
-- No fixity declaration accompanies this operator here, so it has GHC's
-- default fixity (@infixl 9@).
(?) :: Conditional b a => b -> a -> a -> a
cond ? onTrue = ifThenElse cond onTrue

-- | Selection between two circuit values of layout @f@, controlled by a
-- symbolic 'Bool'.
--
-- Both branches are always arithmetized; nothing is short-circuited.
-- For each position of @f@ the result is @(1 - c) * x + c * y@. Here @c@ is
-- the condition, @x@ the false branch and @y@ the true branch. The same
-- formula is evaluated on the witness payloads.
-- NOTE(review): this is a selection only when @c@ is 0 or 1; that is
-- presumably guaranteed by how 'Bool' values are built (not checked here).
instance ( Symbolic c
         , Traversable f
         , Representable f
         ) => Conditional (Bool c) (c f) where
    bool onFalse onTrue (Bool cond) = restore $ \supp ->
      let falseCircuit = arithmetize onFalse supp
          trueCircuit  = arithmetize onTrue supp
          -- Circuit part: three fresh variables per layout position, the
          -- last of which is the selected value.
          resCircuit =
            fromCircuit3F cond falseCircuit trueCircuit $ \(Par1 sel) ->
              mzipWithMRep $ \i j -> do
                -- NOTE(review): each assignment is a single product or sum,
                -- presumably to keep every constraint within the backend's
                -- gate shape; confirm before merging these terms.
                fPart <- newAssigned (\w -> (one - w sel) * w i)
                tPart <- newAssigned (\w -> w sel * w j)
                newAssigned (\w -> w fPart + w tPart)
          -- Witness part: the same blend applied to the witness values.
          condW = unPar1 (witnessF cond)
          blendW wx wy = (one - condW) * wx + condW * wy
       in ( resCircuit
          , mzipWithRep blendW (payload onFalse supp) (payload onTrue supp)
          )

-- | @Bool c@ wraps a @c Par1@ circuit. Selection between symbolic booleans
-- therefore reuses the @Conditional (Bool c) (c f)@ instance at @f = Par1@,
-- via newtype deriving.
deriving newtype instance (Symbolic c) => Conditional (Bool c) (Bool c)

-- | 'Proxy' carries no data, so there is nothing to select. The result is
-- always 'Proxy', and none of the arguments are inspected.
instance Symbolic c => Conditional (Bool c) (Proxy c) where
  bool = \_onFalse _onTrue _cond -> Proxy

-- | Plain Haskell booleans. This delegates to "Data.Bool"'s @bool@, which
-- uses the same false-branch-first argument order.
instance Conditional Prelude.Bool Prelude.Bool where
  bool onFalse onTrue cond = H.bool onFalse onTrue cond

-- | Strings chosen by a plain Haskell 'Prelude.Bool'. The whole string is
-- picked at once via "Data.Bool"'s @bool@.
instance Conditional Prelude.Bool Prelude.String where
  bool onFalse onTrue cond = H.bool onFalse onTrue cond

-- | @Zp n@ values chosen by a plain Haskell 'Prelude.Bool'. One of the two
-- values is returned as-is via "Data.Bool"'s @bool@.
instance Conditional Prelude.Bool (Zp n) where
  bool onFalse onTrue cond = H.bool onFalse onTrue cond

-- | Vectors select element-wise. Each position of the result is chosen from
-- the same position of @onFalse@ / @onTrue@ by one shared condition.
--
-- The condition type is named @b@, as in the other instances in this module.
-- The previous name @bool@ was easy to confuse with the method.
instance (KnownNat n, Conditional b x) => Conditional b (Vector n x) where
  bool onFalse onTrue cond =
    mzipWithRep (\f t -> bool f t cond) onFalse onTrue

-- | Field extensions use the 'Generic'-based default 'bool', selecting field
-- by field. The context supplies the 'Conditional' instance for @field@.
-- The condition type is named @b@, consistent with the other instances.
instance Conditional b field => Conditional b (Ext2 field i)
instance Conditional b field => Conditional b (Ext3 field i)

-- | Tuples have no explicit 'bool'. They use the 'Generic'-based default,
-- which selects every component with its own 'Conditional' instance.
instance (Conditional b x0, Conditional b x1) => Conditional b (x0, x1)

instance (Conditional b x0, Conditional b x1, Conditional b x2)
  => Conditional b (x0, x1, x2)

instance (Conditional b x0, Conditional b x1, Conditional b x2, Conditional b x3)
  => Conditional b (x0, x1, x2, x3)

-- | Generic counterpart of 'Conditional', defined over "GHC.Generics"
-- representation types. It backs the default implementation of 'bool'.
class BoolType b => GConditional b u where
    -- | Same argument order as 'bool': false branch, true branch, condition.
    gbool :: u x -> u x -> b -> u x

-- | Products select each side independently with the same condition.
--
-- No explicit @BoolType b@ constraint is needed: the superclass of
-- @GConditional b u@ already provides it.
instance (GConditional b u, GConditional b v) => GConditional b (u :*: v) where
  gbool (f0 :*: f1) (t0 :*: t1) b = gbool f0 t0 b :*: gbool f1 t1 b

-- | Metadata wrappers ('M1') are transparent. Unwrap both branches, select
-- on the contents, and wrap the result again.
instance GConditional b v => GConditional b (M1 i c v) where
  gbool onFalse onTrue cond = M1 $ gbool (unM1 onFalse) (unM1 onTrue) cond

-- | Leaf fields ('Rec0', i.e. 'K1' 'R') delegate to the field type's own
-- 'Conditional' instance.
instance Conditional b x => GConditional b (Rec0 x) where
  gbool (K1 f) (K1 t) b = K1 (bool f t b)

-- | A constructor without fields ('U1') carries no data, so there is nothing
-- to select and the result is always 'U1'. Without this instance, the
-- 'Generic' default of 'bool' could not be used for single-constructor types
-- with no fields (e.g. @data Unit = Unit@).
instance BoolType b => GConditional b U1 where
  gbool _ _ _ = U1