module ZkFold.Symbolic.Cardano.Contracts.BatchTransfer where

import           Data.Maybe                                     (fromJust)
import           Data.Zip                                       (zip)
import           Numeric.Natural                                (Natural)
import           Prelude                                        hiding (Bool, Eq (..), all, length, splitAt, zip, (&&),
                                                                 (*), (+))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Data.Vector                        (Vector, fromVector, toVector)
import           ZkFold.Symbolic.Algorithms.Hash.MiMC           (mimcHash)
import           ZkFold.Symbolic.Algorithms.Hash.MiMC.Constants (mimcConstants)
import           ZkFold.Symbolic.Cardano.Types                  (Input, Output, Transaction, paymentCredential,
                                                                 txInputs, txOutputs, txiOutput, txoAddress)
import           ZkFold.Symbolic.Compiler                       (ArithmeticCircuit, SymbolicData (pieces))
import           ZkFold.Symbolic.Data.Bool                      (Bool, BoolType (..), all)
import           ZkFold.Symbolic.Data.ByteString
import           ZkFold.Symbolic.Data.Combinators
import           ZkFold.Symbolic.Data.Eq
import           ZkFold.Symbolic.Data.UInt
import           ZkFold.Symbolic.Types                          (Symbolic)

-- | Transaction layout handled by 'batchTransfer': 6 inputs, 0 @rinputs@
-- (presumably reference inputs — confirm), 11 outputs, the @tokens@
-- parameter set to 10, and a unit @()@ datum.
type Tx a = Transaction 6 0 11 10 () a

-- | An input of a 'Tx' (same @tokens@ and datum parameters as 'TxOut').
type TxIn a  = Input 10 () a

-- | An output of a 'Tx': @tokens@ parameter 10, unit @()@ datum.
type TxOut a = Output 10 () a

-- | Values of type @x@ that can be hashed to a digest of type @a@.
class Hash a x where
    -- | Compute the hash of a value.
    hash
        :: x  -- ^ value to hash
        -> a  -- ^ resulting digest

-- | Hash any symbolic datum by MiMC-hashing the circuits returned by 'pieces'.
--
-- The shape of the computation depends on how many pieces the datum has:
--
-- * no pieces: the hash is 'zero';
--
-- * one piece @x@: @mimcHash mimcConstants zero zero x@;
--
-- * two pieces @xL@, @xR@: @mimcHash mimcConstants zero xL xR@;
--
-- * three or more pieces @xL : xR : xZ@: the remaining pieces @xZ@, with a
--   'zero' prepended and appended, are passed to 'mimcHash' in the position
--   where the other branches pass 'mimcConstants'.
--
-- NOTE(review): in the last branch 'mimcConstants' is not used at all.
-- Presumably the first argument of 'mimcHash' is the round-constant list.
-- If so, the extra pieces act as round constants and the number of rounds
-- depends on the datum size. This reproduces the existing behaviour exactly;
-- confirm it is intentional before relying on it cryptographically.
instance SymbolicData a x => Hash (ArithmeticCircuit a) x where
    hash datum = case pieces datum of
        []         -> zero
        [x]        -> mimcHash mimcConstants zero zero x
        [xL, xR]   -> mimcHash mimcConstants zero xL xR
        (xL:xR:xZ) -> mimcHash (zero : xZ ++ [zero]) zero xL xR

-- | Constraints 'verifySignature' needs on the symbolic backend @a@:
--
-- * hashing a 'TxOut' to an @a@;
--
-- * widening a 224-bit 'ByteString' to 256 bits;
--
-- * converting between 256-bit 'ByteString's and 'UInt's;
--
-- * converting an @a@ into a @UInt 256 a@;
--
-- * multiplying 256-bit 'UInt's and comparing them symbolically.
type Sig a =
    ( Hash a (TxOut a)
    , Extend (ByteString 224 a) (ByteString 256 a)
    , Iso (UInt 256 a) (ByteString 256 a)
    , StrictConv a (UInt 256 a)
    , MultiplicativeSemigroup (UInt 256 a)
    , Eq (Bool a) (UInt 256 a)
    )

-- | Check the signature attached to a single transfer.
--
-- Returns the symbolic result of the equation
--
-- > from sig * base == strictConv (mimcHash mimcConstants zero (hash pay) (hash change)) * from (extend pub)
--
-- Both sides are @UInt 256 a@ values. @pub@ is widened from 224 to 256 bits
-- with 'extend', and @base@ is the fixed constant defined in the @where@
-- clause.
--
-- NOTE(review): this is a single multiplicative equation in which every
-- value except @sig@ is derived from public data. Unless 'UInt'
-- multiplication has properties not visible here, a matching @sig@ may be
-- computable without any secret. Confirm this is a placeholder rather than
-- a secure signature scheme.
verifySignature
    :: forall a . (Symbolic a, Sig a)
    => ByteString 224 a    -- ^ @pub@: 224-bit key (a payment credential in 'batchTransfer')
    -> (TxOut a, TxOut a)  -- ^ @(pay, change)@: the outputs covered by the signature
    -> ByteString 256 a    -- ^ @sig@: 256-bit signature
    -> Bool a              -- ^ symbolic result of the equality check
verifySignature pub (pay, change) sig = lhs == rhs
    where
        -- Signature read as an integer, scaled by the fixed constant.
        lhs :: UInt 256 a
        lhs = from sig * base

        -- Hash of the two outputs, scaled by the widened key.
        rhs :: UInt 256 a
        rhs = strictConv mimc * from (extend pub :: ByteString 256 a)

        -- Fixed public multiplier. NOTE(review): its origin is not documented here.
        base :: UInt 256 a
        base = fromConstant (15112221349535400772501151409588531511454012693041857206046113283949847762202 :: Natural)

        -- MiMC hash of the payment and change outputs.
        mimc :: a
        mimc = mimcHash mimcConstants zero (hash pay) (hash change)

-- | Validate a batch of exactly five transfers against a transaction.
--
-- Each transfer is a triple @(payment, change, signature)@. For @k@ in
-- @0..4@, the result is true iff all of the following hold:
--
-- 1. The signature of transfer @k@ passes 'verifySignature'. The key is the
--    payment credential of the address of the output referenced (via
--    'txiOutput') by transaction input @k@. The 6th input is ignored.
--
-- 2. Transaction output @2k@ equals the payment of transfer @k@.
--
-- 3. Transaction output @2k+1@ equals the change of transfer @k@.
--
-- The 11th transaction output is not constrained.
batchTransfer
    :: forall a . (Symbolic a, Eq (Bool a) (TxOut a), Sig a)
    => Tx a                                           -- ^ transaction to validate
    -> Vector 5 (TxOut a, TxOut a, ByteString 256 a)  -- ^ @(payment, change, signature)@ per transfer
    -> Bool a                                         -- ^ conjunction of the three conditions
batchTransfer tx transfers = condition1 && condition2 && condition3
    where
        -- The 'fromJust' calls below cannot fail. 'Tx' fixes 6 inputs and
        -- 11 outputs, so 'init' leaves 5 inputs and 10 outputs, and those
        -- 10 outputs split into 5 even-indexed and 5 odd-indexed ones.

        -- Payment credentials of the outputs referenced by the first 5 inputs.
        pkhs :: Vector 5 (ByteString 224 a)
        pkhs = fromJust . toVector @5
            . map (paymentCredential . txoAddress . txiOutput)
            . init . fromVector
            $ txInputs tx

        -- The first 10 transaction outputs, tagged with their index.
        outputs :: [(Integer, TxOut a)]
        outputs = zip [0 ..] . init . fromVector $ txOutputs tx

        -- The 5 outputs whose index satisfies @keep@. This is shared by the
        -- payment and change extraction, which differ only in the predicate.
        outputsWhere keep = fromJust . toVector @5 . map snd $ filter (keep . fst) outputs

        payments, changes :: Vector 5 (TxOut a)
        payments = outputsWhere (even @Integer)
        changes  = outputsWhere (odd @Integer)

        condition1, condition2, condition3 :: Bool a
        -- Every signature verifies against the key of the matching input.
        condition1 =
            all (\(pkh, (payment, change, signature)) -> verifySignature pkh (payment, change) signature)
                (zip pkhs transfers)
        -- Even-indexed outputs are exactly the payments.
        condition2 = all (\(p', (p, _, _)) -> p' == p) (zip payments transfers)
        -- Odd-indexed outputs are exactly the changes.
        condition3 = all (\(c', (_, c, _)) -> c' == c) (zip changes transfers)