{-# LANGUAGE TypeOperators #-}

module ZkFold.Symbolic.Class (module ZkFold.Symbolic.Class, Arithmetic) where

import           Control.Monad
import           Data.Foldable                    (Foldable)
import           Data.Function                    ((.))
import           Data.Functor                     ((<$>))
import           Data.Kind                        (Type)
import           Data.Type.Equality               (type (~))
import           GHC.Generics                     (type (:.:) (unComp1))
import           Numeric.Natural                  (Natural)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Control.HApplicative (HApplicative (hpair, hunit))
import           ZkFold.Base.Data.Package         (Package (pack))
import           ZkFold.Base.Data.Product         (uncurryP)
import           ZkFold.Symbolic.MonadCircuit

-- | Mappings between functors that are computed inside a circuit.
-- @f@ is the input functor, @g@ is the output functor and @a@ is the
-- base field.
--
-- A value of this type must work for every circuit builder @m@ over @a@
-- and for every choice of variable type @i@: it turns @f@ many inputs
-- into @g@ many outputs using only @m@.
--
-- NOTE: any function with this type satisfies the property above by
-- construction, so there is nothing to verify by hand.
type CircuitFun f g a =
    forall i m. MonadCircuit i a m => f i -> m (g i)

-- | A Symbolic DSL for performant pure computations with arithmetic circuits.
-- @c@ is a generic context in which computations are performed.
class (HApplicative c, Package c, Arithmetic (BaseField c)) => Symbolic c where
    -- | Base algebraic field over which computations are performed.
    type BaseField c :: Type

    -- | To perform computations in a generic context @c@ -- that is,
    -- to form a mapping between @c f@ and @c g@ for given @f@ and @g@ --
    -- you need to provide two things:
    --
    -- 1. An algorithm for turning @f@ into @g@ in a pure context;
    -- 2. An algorithm for turning @f@ into @g@ inside a circuit.
    --
    -- It is not (yet) checked that the two provided algorithms
    -- compute the same thing.
    --
    -- If the pure computation can be obtained simply by running the
    -- circuit computation on plain field elements, use @'fromCircuitF'@.
    symbolicF :: BaseField c ~ a => c f -> (f a -> g a) -> CircuitFun f g a -> c g

    -- | A wrapper around @'symbolicF'@ which extracts the pure computation
    -- from the circuit computation using the @'Witnesses'@ newtype.
    fromCircuitF :: c f -> CircuitFun f g (BaseField c) -> c g
    -- The polymorphic circuit computation @f@ is instantiated at
    -- @i ~ BaseField c@ and @m ~ Witnesses Natural (BaseField c)@;
    -- 'runWitnesses' then unwraps it into the pure @f a -> g a@ function.
    -- The same @f@ is passed unchanged as the circuit computation.
    fromCircuitF x f = symbolicF x (runWitnesses @Natural @(BaseField c) . f) f

-- | Embeds the pure value(s) into generic context @c@.
--
-- The result is built from the empty context 'hunit'. The circuit
-- computation ignores its (empty) input and returns every value of @cs@,
-- each lifted into a circuit variable with 'fromConstant'.
embed :: (Symbolic c, Functor f) => f (BaseField c) -> c f
embed cs = fromCircuitF hunit (\_ -> return (fromConstant <$> cs))

-- | Runs the binary function from @f@ and @g@ into @h@ in a generic context @c@.
--
-- The two arguments are combined into one context with 'hpair'. Both the
-- pure and the circuit computations are uncurried with 'uncurryP' so that
-- they fit the unary 'symbolicF'.
symbolic2F ::
    (Symbolic c, BaseField c ~ a) => c f -> c g -> (f a -> g a -> h a) ->
    (forall i m. MonadCircuit i a m => f i -> g i -> m (h i)) -> c h
symbolic2F x y f m = symbolicF (hpair x y) (uncurryP f) (uncurryP m)

-- | Runs the binary @'CircuitFun'@ in a generic context.
--
-- The two arguments are combined into one context with 'hpair', and the
-- circuit computation is uncurried with 'uncurryP' before being handed to
-- 'fromCircuitF'.
fromCircuit2F ::
    Symbolic c => c f -> c g ->
    (forall i m. MonadCircuit i (BaseField c) m => f i -> g i -> m (h i)) -> c h
fromCircuit2F x y m = fromCircuitF (hpair x y) (uncurryP m)

-- | Runs the ternary function from @f@, @g@ and @h@ into @k@ in a context @c@.
--
-- The first two arguments are paired with 'hpair', and both computations
-- are partially uncurried with 'uncurryP'. This reduces the call to
-- 'symbolic2F'.
symbolic3F ::
    (Symbolic c, BaseField c ~ a) => c f -> c g -> c h -> (f a -> g a -> h a -> k a) ->
    (forall i m. MonadCircuit i a m => f i -> g i -> h i -> m (k i)) -> c k
symbolic3F x y z f m = symbolic2F (hpair x y) z (uncurryP f) (uncurryP m)

-- | Runs the ternary @'CircuitFun'@ in a generic context.
--
-- The first two arguments are paired with 'hpair', and the circuit
-- computation is partially uncurried with 'uncurryP'. This reduces the
-- call to 'fromCircuit2F'.
fromCircuit3F ::
    Symbolic c => c f -> c g -> c h ->
    (forall i m. MonadCircuit i (BaseField c) m => f i -> g i -> h i -> m (k i)) -> c k
fromCircuit3F x y z m = fromCircuit2F (hpair x y) z (uncurryP m)

-- | Given a generic context @c@, runs the function from @f@ many @c g@'s into @c h@.
--
-- The @f@ many contexts are packed into a single @c (f :.: g)@ with 'pack'.
-- Both computations are precomposed with 'unComp1' so they receive the
-- plain @f (g _)@ value.
symbolicVF ::
    (Symbolic c, BaseField c ~ a, Foldable f, Functor f) =>
    f (c g) -> (f (g a) -> h a) ->
    (forall i m. MonadCircuit i a m => f (g i) -> m (h i)) -> c h
symbolicVF xs f m = symbolicF (pack xs) (f . unComp1) (m . unComp1)

-- | Given a generic context @c@, runs the @'CircuitFun'@ from @f@ many @c g@'s into @c h@.
--
-- The @f@ many contexts are packed into a single @c (f :.: g)@ with 'pack'.
-- The circuit computation is precomposed with 'unComp1' so it receives the
-- plain @f (g i)@ value.
fromCircuitVF ::
    (Symbolic c, Foldable f, Functor f) => f (c g) ->
    (forall i m. MonadCircuit i (BaseField c) m => f (g i) -> m (h i)) -> c h
fromCircuitVF xs m = fromCircuitF (pack xs) (m . unComp1)