{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Class (
        SymbolicData (..),
        GSymbolicData (..)
    ) where

import           Control.Applicative              ((<*>))
import           Data.Function                    (flip, (.))
import           Data.Functor                     ((<$>))
import           Data.Functor.Rep                 (Representable (..))
import           Data.Kind                        (Type)
import           Data.Type.Equality               (type (~))
import           Data.Typeable                    (Proxy (..))
import           GHC.Generics                     (U1 (..), (:*:) (..), (:.:) (..))
import qualified GHC.Generics                     as G

import           ZkFold.Base.Algebra.Basic.Number (KnownNat)
import           ZkFold.Base.Control.HApplicative (HApplicative, hliftA2, hpure)
import           ZkFold.Base.Data.HFunctor        (hmap)
import           ZkFold.Base.Data.Package         (Package, pack)
import           ZkFold.Base.Data.Product         (fstP, sndP)
import           ZkFold.Base.Data.Vector          (Vector)

-- | A class for Symbolic data types.
--
-- A symbolic value is described by three associated types. 'Context' is the
-- context in which 'pieces' returns the circuit. 'Support' is the extra
-- argument that 'pieces' takes and that 'restore' supplies. 'Layout' is the
-- layout of the circuit's outputs inside that context.
--
-- All associated types and both methods default to the generic
-- implementation ('GSymbolicData') applied to the type's 'G.Rep'. This lets
-- product-like types with a 'G.Generic' instance get an instance for free.
class SymbolicData x where

    -- | Context in which the circuit for @x@ lives.
    type Context x :: (Type -> Type) -> Type
    type Context x = GContext (G.Rep x)

    -- | Extra input the circuit for @x@ is parameterised by.
    type Support x :: Type
    type Support x = GSupport (G.Rep x)

    -- | Layout of the circuit's outputs.
    type Layout x :: Type -> Type
    type Layout x = GLayout (G.Rep x)

    -- | Returns the circuit that makes up @x@, given its 'Support'.
    pieces :: x -> Support x -> Context x (Layout x)
    default pieces
      :: ( G.Generic x
         , GSymbolicData (G.Rep x)
         , Context x ~ GContext (G.Rep x)
         , Support x ~ GSupport (G.Rep x)
         , Layout x ~ GLayout (G.Rep x)
         )
      => x -> Support x -> Context x (Layout x)
    -- Generic default: convert to the generic representation and let
    -- 'gpieces' build the circuit.
    pieces x supp = gpieces (G.from x) supp

    -- | Restores @x@ from the circuit's outputs.
    restore :: (Support x -> Context x (Layout x)) -> x
    default restore
      :: ( G.Generic x
         , GSymbolicData (G.Rep x)
         , Context x ~ GContext (G.Rep x)
         , Support x ~ GSupport (G.Rep x)
         , Layout x ~ GLayout (G.Rep x)
         )
      => (Support x -> Context x (Layout x)) -> x
    -- Generic default: rebuild the generic representation with 'grestore'
    -- and convert it back with 'G.to'.
    restore f = G.to (grestore f)

-- | A bare circuit @c f@ is already in symbolic form. Its context is @c@,
-- its layout is @f@, and its only support is the trivial @'Proxy' c@.
instance SymbolicData (c (f :: Type -> Type)) where
    type Context (c f) = c
    type Support (c f) = Proxy c
    type Layout (c f) = f

    -- The value is its own circuit; the (trivial) support is ignored.
    pieces x _ = x
    -- Feed the trivial support to obtain the circuit back.
    restore f = f Proxy

-- | @'Proxy' c@ carries no data. Its circuit is the empty ('U1') layout in
-- context @c@.
instance HApplicative c => SymbolicData (Proxy (c :: (Type -> Type) -> Type)) where
    type Context (Proxy c) = c
    type Support (Proxy c) = Proxy c
    type Layout (Proxy c) = U1

    -- Both arguments are ignored; the circuit is an empty layout built
    -- with 'hpure'.
    pieces _ _ = hpure U1
    -- Nothing to read back: the only value is 'Proxy'.
    restore _ = Proxy

-- | Pairs whose components share one 'Context' and one 'Support'.
-- Both methods come from the 'G.Generic'-based defaults of 'SymbolicData'.
instance
    ( SymbolicData x
    , SymbolicData y
    , Context x ~ Context y
    , Support x ~ Support y
    , HApplicative (Context x)
    ) => SymbolicData (x, y)

-- | Triples whose components share one 'Context' and one 'Support'.
-- Both methods come from the 'G.Generic'-based defaults of 'SymbolicData'.
instance
    ( SymbolicData x
    , SymbolicData y
    , SymbolicData z
    , Context x ~ Context y
    , Context y ~ Context z
    , Support x ~ Support y
    , Support y ~ Support z
    , HApplicative (Context x)
    ) => SymbolicData (x, y, z)

-- | 4-tuples whose components share one 'Context' and one 'Support'.
-- Both methods come from the 'G.Generic'-based defaults of 'SymbolicData'.
instance
    ( SymbolicData w
    , SymbolicData x
    , SymbolicData y
    , SymbolicData z
    , Context w ~ Context x
    , Context x ~ Context y
    , Context y ~ Context z
    , Support w ~ Support x
    , Support x ~ Support y
    , Support y ~ Support z
    , HApplicative (Context x)
    ) => SymbolicData (w, x, y, z)

-- | 5-tuples whose components share one 'Context' and one 'Support'.
-- Both methods come from the 'G.Generic'-based defaults of 'SymbolicData'.
instance
    ( SymbolicData v
    , SymbolicData w
    , SymbolicData x
    , SymbolicData y
    , SymbolicData z
    , Context v ~ Context w
    , Context w ~ Context x
    , Context x ~ Context y
    , Context y ~ Context z
    , Support v ~ Support w
    , Support w ~ Support x
    , Support x ~ Support y
    , Support y ~ Support z
    , HApplicative (Context x)
    ) => SymbolicData (v, w, x, y, z)

-- | A vector of @n@ symbolic values sharing one context and support.
-- The element circuits are packed into a single circuit whose layout is the
-- composition @'Vector' n ':.:' 'Layout' x@.
instance
    ( SymbolicData x
    , Package (Context x)
    , KnownNat n
    ) => SymbolicData (Vector n x) where

    type Context (Vector n x) = Context x
    type Support (Vector n x) = Support x
    type Layout (Vector n x) = Vector n :.: Layout x

    -- Build every element's circuit against the same support, then 'pack'
    -- the vector of circuits into one circuit.
    pieces xs supp = pack (flip pieces supp <$> xs)
    -- Element @ix@ is restored from the packed circuit. Each output layout
    -- is unwrapped ('unComp1'), and its @ix@-th component is selected
    -- ('index').
    restore f = tabulate (\ix -> restore (hmap (flip index ix . unComp1) . f))

-- | A function into symbolic data. Its support carries the argument in
-- addition to the result's support, and the argument is threaded through
-- to the result.
instance SymbolicData f => SymbolicData (x -> f) where
    type Context (x -> f) = Context f
    type Support (x -> f) = (x, Support f)
    type Layout (x -> f) = Layout f

    -- Apply the function to the argument taken from the support, then
    -- delegate to the result type.
    pieces f (x, i) = pieces (f x) i
    -- For an argument @x@, restore the result from @f@, pairing @x@ onto
    -- the result's support.
    restore f x = restore (f . (x,))

-- | Generic counterpart of 'SymbolicData', defined over the building blocks
-- of 'G.Rep'. It supplies the default methods of 'SymbolicData'.
class GSymbolicData u where
    -- | Context of the representation's circuit.
    type GContext u :: (Type -> Type) -> Type
    -- | Extra input the representation's circuit is parameterised by.
    type GSupport u :: Type
    -- | Layout of the representation's circuit outputs.
    type GLayout u :: Type -> Type

    -- | Generic version of 'pieces'.
    gpieces
      :: u p
      -> GSupport u
      -> GContext u (GLayout u)
    -- | Generic version of 'restore'.
    grestore
      :: (GSupport u -> GContext u (GLayout u))
      -> u p

-- | Products of representations whose halves share one context and support.
-- The resulting layout is the product of the halves' layouts.
instance
    ( GSymbolicData u
    , GSymbolicData v
    , HApplicative (GContext u)
    , GContext u ~ GContext v
    , GSupport u ~ GSupport v
    ) => GSymbolicData (u :*: v) where

    type GContext (u :*: v) = GContext u
    type GSupport (u :*: v) = GSupport u
    type GLayout (u :*: v) = GLayout u :*: GLayout v

    -- Both halves are built against the same support: '<$>' / '<*>' act on
    -- functions of the support. The two circuits are zipped with 'hliftA2'
    -- into a product layout.
    gpieces (a :*: b) = hliftA2 (:*:) <$> gpieces a <*> gpieces b
    -- Each half is restored from the product circuit by projecting its
    -- component out of the layout with 'fstP' / 'sndP'.
    grestore f = grestore (hmap fstP . f) :*: grestore (hmap sndP . f)

-- | Metadata wrappers ('G.M1') are transparent: everything is delegated to
-- the wrapped representation.
instance GSymbolicData f => GSymbolicData (G.M1 i c f) where
    type GContext (G.M1 i c f) = GContext f
    type GSupport (G.M1 i c f) = GSupport f
    type GLayout (G.M1 i c f) = GLayout f
    gpieces (G.M1 a) = gpieces a
    grestore f = G.M1 (grestore f)

-- | A single field ('G.Rec0') delegates to the 'SymbolicData' instance of
-- the field's type.
instance SymbolicData x => GSymbolicData (G.Rec0 x) where
    type GContext (G.Rec0 x) = Context x
    type GSupport (G.Rec0 x) = Support x
    type GLayout (G.Rec0 x) = Layout x
    gpieces (G.K1 x) = pieces x
    grestore f = G.K1 (restore f)