{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}

-- | Property tests that sign Google and Twitch JWT payloads with an RS256
-- signing key and check that 'verifyJWT' accepts each of them against the
-- certificate built from the same RSA key pair.
module Tests.Symbolic.Algorithm.JWT (specJWT) where

import Codec.Crypto.RSA (generateKeyPair)
import qualified Codec.Crypto.RSA as R
import Data.Function (($))
import GHC.Generics (Par1 (..))
import Prelude (pure)
import qualified Prelude as P
import System.Random (mkStdGen)
import Test.Hspec (Spec, describe)
import Test.QuickCheck (Gen, arbitrary, withMaxSuccess, (.&.), (===))

import Tests.Symbolic.ArithmeticCircuit (it)
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (Fr)
import ZkFold.Prelude (chooseNatural)
import ZkFold.Symbolic.Algorithms.RSA
import ZkFold.Symbolic.Data.Bool
import ZkFold.Symbolic.Data.JWT
import ZkFold.Symbolic.Data.JWT.Google
import ZkFold.Symbolic.Data.JWT.RS256
import ZkFold.Symbolic.Data.JWT.Twitch
import ZkFold.Symbolic.Data.VarByteString (VarByteString, fromNatural)
import ZkFold.Symbolic.Interpreter (Interpreter (Interpreter))

-- | Concrete evaluation context used throughout this module: the
-- 'Interpreter' over 'Fr' from BLS12-381.
type I = Interpreter Fr

-- | Draw a natural number with 'chooseNatural' between @0@ and @upper -! 1@.
--
-- NOTE(review): presumably the bounds are inclusive, like QuickCheck's
-- @choose@, which would make this the range @[0, upper)@. Confirm against
-- 'chooseNatural'. Behaviour for @upper == 0@ depends on '-!'.
toss :: Natural -> Gen Natural
toss upper = chooseNatural (0, upper -! 1)

-- | Unwrap the single field element stored inside an interpreted 'Bool'.
evalBool :: forall a. Bool (Interpreter a) -> a
evalBool (Bool (Interpreter (Par1 bit))) = bit

-- | Sign-then-verify property for RS256 JWTs carrying Google and Twitch
-- payloads. It runs 10 QuickCheck cases, since each one generates a fresh
-- 2048-bit RSA key pair.
specJWT :: Spec
specJWT = describe "JWT sign and verify" $
  it "signs and verifies correctly" $ withMaxSuccess 10 $ do
    -- The order of these draws determines every generated value:
    -- key-pair seed, 320-bit key id, then one payload of each kind.
    keySeed <- toss $ (2 :: Natural) ^ (32 :: Natural)
    kidValue <- toss $ (2 :: Natural) ^ (320 :: Natural)
    googleClaims <- arbitrary @(GooglePayload I)
    twitchClaims <- arbitrary @(TwitchPayload I)
    -- The RSA key pair is a pure function of the drawn seed. The third
    -- component returned by 'generateKeyPair' is not needed.
    let (R.PublicKey{..}, R.PrivateKey{..}, _) =
          generateKeyPair (mkStdGen (P.fromIntegral keySeed)) 2048
        -- One key id labels both the signing key and the certificate.
        jwtKid = fromNatural 320 kidValue :: VarByteString 320 I
        rsaSigningKey =
          SigningKey jwtKid (PrivateKey (fromConstant private_d) (fromConstant private_n))
        rsaCertificate =
          Certificate jwtKid (PublicKey (fromConstant public_e) (fromConstant public_n))
        (googleHeader, googleSig) = signPayload @"RS256" googleClaims rsaSigningKey
        (twitchHeader, twitchSig) = signPayload @"RS256" twitchClaims rsaSigningKey
        -- Only the first component (the verdict) of 'verifyJWT' is checked.
        googleOk = P.fst (verifyJWT @"RS256" googleHeader googleClaims googleSig rsaCertificate)
        twitchOk = P.fst (verifyJWT @"RS256" twitchHeader twitchClaims twitchSig rsaCertificate)
    pure $ (evalBool googleOk === one) .&. (evalBool twitchOk === one)