{-# LANGUAGE DeriveAnyClass       #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Algorithms.RSA
    ( sign
    , verify
    , RSA
    , PublicKey (..)
    , PrivateKey (..)
    , Signature
    , KeyLength
    ) where

import           Control.DeepSeq                      (NFData, force)
import           GHC.Generics                         (Generic)
import           Prelude                              (($))
import qualified Prelude                              as P

import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Vector              (Vector)
import           ZkFold.Symbolic.Algorithms.Hash.SHA2 (SHA2, sha2)
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool            (Bool, (&&))
import           ZkFold.Symbolic.Data.ByteString      (ByteString)
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Combinators     (Ceil, GetRegisterSize, Iso (..), KnownRegisters,
                                                       NumberOfRegisters, RegisterSize (..), Resize (..))
import           ZkFold.Symbolic.Data.Eq
import           ZkFold.Symbolic.Data.Input           (SymbolicInput, isValid)
import           ZkFold.Symbolic.Data.UInt            (OrdWord, UInt, expMod)

-- | Size of the RSA modulus handled by this module; every key component and
-- every 'Signature' is indexed by this length.
type KeyLength = 512

-- | An RSA signature: a 'ByteString' indexed by 'KeyLength', living in the
-- symbolic context @c@.
type Signature c = ByteString KeyLength c

-- | RSA private key in the symbolic context @ctx@.
--
-- Both components are 'KeyLength'-sized unsigned integers; 'sign' uses them
-- as the exponent and modulus arguments of 'expMod'.
data PrivateKey ctx
    = PrivateKey
        { prvD :: UInt KeyLength 'Auto ctx
          -- ^ Private exponent @d@.
        , prvN :: UInt KeyLength 'Auto ctx
          -- ^ Modulus @n@.
        }

-- Structural instances for 'PrivateKey'. Each one only asks for the matching
-- instance on @c (Vector (NumberOfRegisters (BaseField c) KeyLength 'Auto))@,
-- the type these constraints associate with a @UInt KeyLength 'Auto c@ field.
deriving instance Generic (PrivateKey c)

deriving instance
    NFData (c (Vector (NumberOfRegisters (BaseField c) KeyLength 'Auto)))
    => NFData (PrivateKey c)

deriving instance
    P.Eq (c (Vector (NumberOfRegisters (BaseField c) KeyLength 'Auto)))
    => P.Eq (PrivateKey c)

deriving instance
    ( P.Show (c (Vector (NumberOfRegisters (BaseField c) KeyLength 'Auto)))
    , P.Show (BaseField c)
    ) => P.Show (PrivateKey c)

-- 'SymbolicData' is derived through the class defaults (DeriveAnyClass).
deriving instance
    ( Symbolic c
    , KnownRegisters c KeyLength 'Auto
    ) => SymbolicData (PrivateKey c)

-- | A private key is a valid circuit input exactly when both of its
-- components pass their own 'isValid' check.
instance
  ( Symbolic ctx
  , KnownRegisters ctx KeyLength 'Auto
  ) => SymbolicInput (PrivateKey ctx) where
    isValid PrivateKey { prvD = privExp, prvN = modulus } =
        isValid privExp && isValid modulus

-- | RSA public key in the symbolic context @ctx@.
--
-- 'verify' uses the two components as the exponent and modulus arguments of
-- 'expMod'.
data PublicKey ctx
    = PublicKey
        { pubE :: UInt 32 'Auto ctx
          -- ^ Public exponent @e@ (a 32-sized 'UInt').
        , pubN :: UInt KeyLength 'Auto ctx
          -- ^ Modulus @n@ (a 'KeyLength'-sized 'UInt').
        }

-- Structural instances for 'PublicKey'. Because the two fields have different
-- widths (32 and 'KeyLength'), every instance needs the class on both
-- register-vector types.
deriving instance Generic (PublicKey c)

deriving instance
    ( NFData (c (Vector (NumberOfRegisters (BaseField c) 32 'Auto)))
    , NFData (c (Vector (NumberOfRegisters (BaseField c) KeyLength 'Auto)))
    ) => NFData (PublicKey c)

deriving instance
    ( P.Eq (c (Vector (NumberOfRegisters (BaseField c) 32 'Auto)))
    , P.Eq (c (Vector (NumberOfRegisters (BaseField c) KeyLength 'Auto)))
    ) => P.Eq (PublicKey c)

deriving instance
    ( P.Show (BaseField c)
    , P.Show (c (Vector (NumberOfRegisters (BaseField c) 32 'Auto)))
    , P.Show (c (Vector (NumberOfRegisters (BaseField c) KeyLength 'Auto)))
    ) => P.Show (PublicKey c)

-- 'SymbolicData' is derived through the class defaults (DeriveAnyClass) and
-- needs the register layout of both field widths.
deriving instance
    ( Symbolic c
    , KnownRegisters c 32 'Auto
    , KnownRegisters c KeyLength 'Auto
    ) => SymbolicData (PublicKey c)

-- | A public key is a valid circuit input exactly when both the exponent and
-- the modulus pass their own 'isValid' check.
instance
  ( Symbolic ctx
  , KnownRegisters ctx 32 'Auto
  , KnownRegisters ctx KeyLength 'Auto
  ) => SymbolicInput (PublicKey ctx) where
    isValid PublicKey { pubE = pubExp, pubN = modulus } =
        isValid pubExp && isValid modulus

-- | Constraint bundle shared by 'sign' and 'verify' for the symbolic context
-- @c@ and a message of size @len@.
type RSA c len =
   ( -- hashing the message with SHA-256
     SHA2 "SHA256" c len
     -- 'KeyLength'-sized integers (keys, signature as an integer)
   , KnownRegisters c KeyLength 'Auto
   , NFData (c (Vector (NumberOfRegisters (BaseField c) KeyLength 'Auto)))
     -- double-width requirements that 'expMod' places on @2 * m@,
     -- here with @m = KeyLength@
   , KnownRegisters c (2 * KeyLength) 'Auto
   , KnownNat (Ceil (GetRegisterSize (BaseField c) (2 * KeyLength) 'Auto) OrdWord)
   , NFData (c (Vector (NumberOfRegisters (BaseField c) (2 * KeyLength) 'Auto)))
     -- presumably what 'force' needs for a 'Signature' in 'sign' (TODO confirm)
   , NFData (c (Vector KeyLength))
   )

-- | Sign a message with an RSA private key.
--
-- The message is hashed with SHA-256, the 256-bit digest is reinterpreted as
-- an unsigned integer @m@, and the result of @'expMod' m d n@ is converted
-- back into a 'Signature'. No padding scheme is applied: the digest itself
-- is exponentiated.
sign
    :: forall ctx msgLen
    .  RSA ctx msgLen
    => ByteString msgLen ctx
    -> PrivateKey ctx
    -> Signature ctx
sign msg PrivateKey { prvD = privExp, prvN = modulus } =
    force $ from $ expMod digest privExp modulus
    where
        -- SHA-256 digest of the message.
        h :: ByteString 256 ctx
        h = sha2 @"SHA256" msg

        -- The same digest viewed as a 256-bit unsigned integer.
        digest :: UInt 256 'Auto ctx
        digest = from h

-- | Check an RSA signature produced by 'sign'.
--
-- The signature is converted to a 'KeyLength'-sized integer and raised to
-- the public exponent with 'expMod'. The result is compared, using the
-- symbolic '==', against the SHA-256 digest of @msg@ resized from 256 to
-- 'KeyLength'.
--
-- NOTE(review): @sig@ is not range-checked against the modulus. Since
-- @'expMod'@ works modulo @n@, @sig@ and @sig + n@ (when it still fits in
-- 'KeyLength') would both verify. Confirm whether callers must enforce
-- @sig < n@ themselves.
verify
    :: forall ctx msgLen
    .  RSA ctx msgLen
    => ByteString msgLen ctx
    -> Signature ctx
    -> PublicKey ctx
    -> Bool ctx
verify msg sig PublicKey { pubE = pubExp, pubN = modulus } = target == input
    where
        -- SHA-256 digest of the message.
        h :: ByteString 256 ctx
        h = sha2 @"SHA256" msg

        -- Signature raised to the public exponent modulo the modulus.
        target :: UInt KeyLength 'Auto ctx
        target = force $ expMod (from sig :: UInt KeyLength 'Auto ctx) pubExp modulus

        -- Digest as a 256-bit integer, resized to the key width for comparison.
        input :: UInt KeyLength 'Auto ctx
        input = force $ resize (from h :: UInt 256 'Auto ctx)