{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Crypto.JOSE.JWS
  (
  
    JWS
  , GeneralJWS
  , FlattenedJWS
  , CompactJWS
  
  
  
  , newJWSHeader
  , makeJWSHeader
  , signJWS
  
  , verifyJWS
  , verifyJWS'
  , verifyJWSWithPayload
  
  , defaultValidationSettings
  , ValidationSettings
  , ValidationPolicy(..)
  , HasValidationSettings(..)
  , HasAlgorithms(..)
  , HasValidationPolicy(..)
  
  , signatures
  , Signature
  , header
  , signature
  , rawProtectedHeader
  
  , Alg(..)
  , HasJWSHeader(..)
  , JWSHeader
  , module Crypto.JOSE.Error
  , module Crypto.JOSE.Header
  , module Crypto.JOSE.JWK
  ) where
import Control.Applicative ((<|>))
import Data.Foldable (toList)
import Data.Maybe (catMaybes, fromMaybe)
import Data.Monoid ((<>))
import Data.List.NonEmpty (NonEmpty)
import Data.Traversable (traverse)
import Data.Word (Word8)
import Control.Lens hiding ((.=))
import Control.Lens.Cons.Extras (recons)
import Control.Monad.Error.Lens (throwing, throwing_)
import Control.Monad.Except (MonadError, unless)
import Data.Aeson
import qualified Data.ByteString as B
import qualified Data.HashMap.Strict as M
import qualified Data.Set as S
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Crypto.JOSE.Compact
import Crypto.JOSE.Error
import Crypto.JOSE.JWA.JWS
import Crypto.JOSE.JWK
import Crypto.JOSE.JWK.Store
import Crypto.JOSE.Header
import qualified Crypto.JOSE.Types as Types
import qualified Crypto.JOSE.Types.Internal as Types
jwsCritInvalidNames :: [T.Text]
jwsCritInvalidNames = [
  "alg"
  , "jku"
  , "jwk"
  , "x5u"
  , "x5t"
  , "x5t#S256"
  , "x5c"
  , "kid"
  , "typ"
  , "cty"
  , "crit"
  ]
data JWSHeader p = JWSHeader
  { _jwsHeaderAlg :: HeaderParam p Alg
  , _jwsHeaderJku :: Maybe (HeaderParam p Types.URI)  
  , _jwsHeaderJwk :: Maybe (HeaderParam p JWK)
  , _jwsHeaderKid :: Maybe (HeaderParam p T.Text)  
  , _jwsHeaderX5u :: Maybe (HeaderParam p Types.URI)
  , _jwsHeaderX5c :: Maybe (HeaderParam p (NonEmpty Types.SignedCertificate))
  , _jwsHeaderX5t :: Maybe (HeaderParam p Types.Base64SHA1)
  , _jwsHeaderX5tS256 :: Maybe (HeaderParam p Types.Base64SHA256)
  , _jwsHeaderTyp :: Maybe (HeaderParam p T.Text)  
  , _jwsHeaderCty :: Maybe (HeaderParam p T.Text)  
  , _jwsHeaderCrit :: Maybe (NonEmpty T.Text)
  }
  deriving (Eq, Show)
class HasJWSHeader a where
  jwsHeader :: Lens' (a p) (JWSHeader p)
instance HasJWSHeader JWSHeader where
  jwsHeader = id
instance HasJWSHeader a => HasAlg a where
  alg = jwsHeader . \f h@(JWSHeader { _jwsHeaderAlg = a }) ->
    fmap (\a' -> h { _jwsHeaderAlg = a' }) (f a)
instance HasJWSHeader a => HasJku a where
  jku = jwsHeader . \f h@(JWSHeader { _jwsHeaderJku = a }) ->
    fmap (\a' -> h { _jwsHeaderJku = a' }) (f a)
instance HasJWSHeader a => HasJwk a where
  jwk = jwsHeader . \f h@(JWSHeader { _jwsHeaderJwk = a }) ->
    fmap (\a' -> h { _jwsHeaderJwk = a' }) (f a)
instance HasJWSHeader a => HasKid a where
  kid = jwsHeader . \f h@(JWSHeader { _jwsHeaderKid = a }) ->
    fmap (\a' -> h { _jwsHeaderKid = a' }) (f a)
instance HasJWSHeader a => HasX5u a where
  x5u = jwsHeader . \f h@(JWSHeader { _jwsHeaderX5u = a }) ->
    fmap (\a' -> h { _jwsHeaderX5u = a' }) (f a)
instance HasJWSHeader a => HasX5c a where
  x5c = jwsHeader . \f h@(JWSHeader { _jwsHeaderX5c = a }) ->
    fmap (\a' -> h { _jwsHeaderX5c = a' }) (f a)
instance HasJWSHeader a => HasX5t a where
  x5t = jwsHeader . \f h@(JWSHeader { _jwsHeaderX5t = a }) ->
    fmap (\a' -> h { _jwsHeaderX5t = a' }) (f a)
instance HasJWSHeader a => HasX5tS256 a where
  x5tS256 = jwsHeader . \f h@(JWSHeader { _jwsHeaderX5tS256 = a }) ->
    fmap (\a' -> h { _jwsHeaderX5tS256 = a' }) (f a)
instance HasJWSHeader a => HasTyp a where
  typ = jwsHeader . \f h@(JWSHeader { _jwsHeaderTyp = a }) ->
    fmap (\a' -> h { _jwsHeaderTyp = a' }) (f a)
instance HasJWSHeader a => HasCty a where
  cty = jwsHeader . \f h@(JWSHeader { _jwsHeaderCty = a }) ->
    fmap (\a' -> h { _jwsHeaderCty = a' }) (f a)
instance HasJWSHeader a => HasCrit a where
  crit = jwsHeader . \f h@(JWSHeader { _jwsHeaderCrit = a }) ->
    fmap (\a' -> h { _jwsHeaderCrit = a' }) (f a)
newJWSHeader :: (p, Alg) -> (JWSHeader p)
newJWSHeader a = JWSHeader (uncurry HeaderParam a) z z z z z z z z z z
  where z = Nothing
makeJWSHeader
  :: forall e m p. (MonadError e m, AsError e, ProtectionIndicator p)
  => JWK
  -> m (JWSHeader p)
makeJWSHeader k = do
  let
    p = getProtected
    f :: ASetter s t a (Maybe (HeaderParam p a1))
      -> Getting (Maybe a1) JWK (Maybe a1)
      -> s -> t
    f lh lk = set lh (HeaderParam p <$> view lk k)
  algo <- bestJWSAlg k
  pure $ newJWSHeader (p, algo)
    & f kid (jwkKid . to (fmap (view recons)))
    & f x5u jwkX5u
    & f x5c jwkX5c
    & f x5t jwkX5t
    & f x5tS256 jwkX5tS256
data Signature p a = Signature
  (Maybe T.Text)      
  (a p)               
  Types.Base64Octets  
  deriving (Show)
header :: Getter (Signature p a) (a p)
header = to (\(Signature _ h _) -> h)
signature :: (Cons s s Word8 Word8, AsEmpty s) => Getter (Signature p a) s
signature = to (\(Signature _ _ (Types.Base64Octets s)) -> s) . recons
instance (Eq (a p)) => Eq (Signature p a) where
  Signature _ h s == Signature _ h' s' = h == h' && s == s'
instance (HasParams a, ProtectionIndicator p) => FromJSON (Signature p a) where
  parseJSON = withObject "signature" (\o -> Signature
    <$> (Just <$> (o .: "protected" <|> pure ""))  
    <*> do
      hpB64 <- o .:? "protected"
      hp <- maybe
        (pure Nothing)
        (withText "base64url-encoded header params"
          (Types.parseB64Url (maybe
            (fail "protected header contains invalid JSON")
            pure . decode . view recons)))
        hpB64
      hu <- o .:? "header"
      parseParams hp hu
    <*> o .: "signature"
    )
instance (HasParams a, ProtectionIndicator p) => ToJSON (Signature p a) where
  toJSON s@(Signature _ h sig) =
    let
      pro = case rawProtectedHeader s of
        "" -> id
        bs -> ("protected" .= String (T.decodeUtf8 (view recons bs)) :)
      unp = case unprotectedParams h of
        Nothing -> id
        Just o -> ("header" .= o :)
    in
      object $ (pro . unp) [("signature" .= sig)]
instance HasParams JWSHeader where
  parseParamsFor proxy hp hu = JWSHeader
    <$> headerRequired "alg" hp hu
    <*> headerOptional "jku" hp hu
    <*> headerOptional "jwk" hp hu
    <*> headerOptional "kid" hp hu
    <*> headerOptional "x5u" hp hu
    <*> ((fmap . fmap . fmap . fmap)
          (\(Types.Base64X509 cert) -> cert) (headerOptional "x5c" hp hu))
    <*> headerOptional "x5t" hp hu
    <*> headerOptional "x5t#S256" hp hu
    <*> headerOptional "typ" hp hu
    <*> headerOptional "cty" hp hu
    <*> (headerOptionalProtected "crit" hp hu
      >>= parseCrit jwsCritInvalidNames (extensions proxy)
        (fromMaybe mempty hp <> fromMaybe mempty hu))
  params h =
    catMaybes
      [ Just (view (alg . isProtected) h, "alg" .= (view (alg . param) h))
      , fmap (\p -> (view isProtected p, "jku" .= view param p)) (view jku h)
      , fmap (\p -> (view isProtected p, "jwk" .= view param p)) (view jwk h)
      , fmap (\p -> (view isProtected p, "kid" .= view param p)) (view kid h)
      , fmap (\p -> (view isProtected p, "x5u" .= view param p)) (view x5u h)
      , fmap (\p -> (view isProtected p, "x5c" .= fmap Types.Base64X509 (view param p))) (view x5c h)
      , fmap (\p -> (view isProtected p, "x5t" .= view param p)) (view x5t h)
      , fmap (\p -> (view isProtected p, "x5t#S256" .= view param p)) (view x5tS256 h)
      , fmap (\p -> (view isProtected p, "typ" .= view param p)) (view typ h)
      , fmap (\p -> (view isProtected p, "cty" .= view param p)) (view cty h)
      , fmap (\p -> (True, "crit" .= p)) (view crit h)
      ]
data JWS t p a = JWS Types.Base64Octets (t (Signature p a))
type GeneralJWS = JWS [] Protection
type FlattenedJWS = JWS Identity Protection
type CompactJWS = JWS Identity ()
instance (Eq (t (Signature p a))) => Eq (JWS t p a) where
  JWS p sigs == JWS p' sigs' = p == p' && sigs == sigs'
instance (Show (t (Signature p a))) => Show (JWS t p a) where
  show (JWS p sigs) = "JWS " <> show p <> " " <> show sigs
signatures :: Foldable t => Fold (JWS t p a) (Signature p a)
signatures = folding (\(JWS _ sigs) -> sigs)
instance (HasParams a, ProtectionIndicator p) => FromJSON (JWS [] p a) where
  parseJSON v =
    withObject "JWS JSON serialization" (\o -> JWS
      <$> o .: "payload"
      <*> o .: "signatures") v
    <|> fmap (\(JWS p (Identity s)) -> JWS p [s]) (parseJSON v)
instance (HasParams a, ProtectionIndicator p) => FromJSON (JWS Identity p a) where
  parseJSON =
    withObject "Flattened JWS JSON serialization" $ \o ->
      if M.member "signatures" o
      then fail "\"signatures\" member MUST NOT be present"
      else (\p s -> JWS p (pure s)) <$> o .: "payload" <*> parseJSON (Object o)
instance (HasParams a, ProtectionIndicator p) => ToJSON (JWS [] p a) where
  toJSON (JWS p [s]) = object $ "payload" .= p : Types.objectPairs (toJSON s)
  toJSON (JWS p ss) = object ["payload" .= p, "signatures" .= ss]
instance (HasParams a, ProtectionIndicator p) => ToJSON (JWS Identity p a) where
  toJSON (JWS p (Identity s)) = object $ "payload" .= p : Types.objectPairs (toJSON s)
signingInput
  :: (HasParams a, ProtectionIndicator p)
  => Signature p a
  -> Types.Base64Octets
  -> B.ByteString
signingInput sig (Types.Base64Octets p) =
  rawProtectedHeader sig <> "." <> review Types.base64url p
rawProtectedHeader
  :: (HasParams a, ProtectionIndicator p)
  => Signature p a -> B.ByteString
rawProtectedHeader (Signature raw h _) =
  maybe (view recons $ protectedParamsEncoded h) T.encodeUtf8 raw
instance HasParams a => ToCompact (JWS Identity () a) where
  toCompact (JWS p (Identity s@(Signature _ _ (Types.Base64Octets sig)))) =
    [ view recons $ signingInput s p
    , review Types.base64url sig
    ]
instance HasParams a => FromCompact (JWS Identity () a) where
  fromCompact xs = case xs of
    [h, p, s] -> do
      (h', p', s') <- (,,) <$> t 0 h <*> t 1 p <*> t 2 s
      let o = object [ ("payload", p'), ("protected", h'), ("signature", s') ]
      case fromJSON o of
        Error e -> throwing _JSONDecodeError e
        Success a -> pure a
    xs' -> throwing (_CompactDecodeError . _CompactInvalidNumberOfParts)
            (InvalidNumberOfParts 3 (fromIntegral (length xs')))
    where
      l = _CompactDecodeError . _CompactInvalidText
      t n = either (throwing l . CompactTextError n) (pure . String)
        . T.decodeUtf8' . view recons
signJWS
  :: ( Cons s s Word8 Word8
     , HasJWSHeader a, HasParams a, MonadRandom m, AsError e, MonadError e m
     , Traversable t
     , ProtectionIndicator p
     )
  => s          
  -> t ((a p), JWK) 
  -> m (JWS t p a)
signJWS s =
  let s' = view recons s
  in fmap (JWS (Types.Base64Octets s')) . traverse (uncurry (mkSignature s'))
mkSignature
  :: ( HasJWSHeader a, HasParams a, MonadRandom m, AsError e, MonadError e m
     , ProtectionIndicator p
     )
  => B.ByteString -> a p -> JWK -> m (Signature p a)
mkSignature p h k =
  let
    almostSig = Signature Nothing h . Types.Base64Octets
  in
    almostSig
    <$> sign
          (view (alg . param) h)
          (k ^. jwkMaterial)
          (signingInput (almostSig "") (Types.Base64Octets p))
data ValidationPolicy
  = AnyValidated
  
  | AllValidated
  
  
  deriving (Eq)
data ValidationSettings = ValidationSettings
  (S.Set Alg)
  ValidationPolicy
class HasValidationSettings a where
  validationSettings :: Lens' a ValidationSettings
  validationSettingsAlgorithms :: Lens' a (S.Set Alg)
  validationSettingsAlgorithms = validationSettings . go where
    go f (ValidationSettings algs pol) =
      (\algs' -> ValidationSettings algs' pol) <$> f algs
  validationSettingsValidationPolicy :: Lens' a ValidationPolicy
  validationSettingsValidationPolicy = validationSettings . go where
    go f (ValidationSettings algs pol) =
      (\pol' -> ValidationSettings algs pol') <$> f pol
instance HasValidationSettings ValidationSettings where
  validationSettings = id
class HasAlgorithms s where
  algorithms :: Lens' s (S.Set Alg)
class HasValidationPolicy s where
  validationPolicy :: Lens' s ValidationPolicy
instance HasValidationSettings a => HasAlgorithms a where
  algorithms = validationSettingsAlgorithms
instance HasValidationSettings a => HasValidationPolicy a where
  validationPolicy = validationSettingsValidationPolicy
defaultValidationSettings :: ValidationSettings
defaultValidationSettings = ValidationSettings
  ( S.fromList
    [ HS256, HS384, HS512
    , RS256, RS384, RS512
    , ES256, ES384, ES512
    , PS256, PS384, PS512
    , EdDSA
    ] )
  AllValidated
verifyJWS'
  ::  ( AsError e, MonadError e m , HasJWSHeader h, HasParams h
      , VerificationKeyStore m (h p) s k
      , Cons s s Word8 Word8, AsEmpty s
      , Foldable t
      , ProtectionIndicator p
      )
  => k      
  -> JWS t p h  
  -> m s
verifyJWS' = verifyJWS defaultValidationSettings
verifyJWS
  ::  ( HasAlgorithms a, HasValidationPolicy a, AsError e, MonadError e m
      , HasJWSHeader h, HasParams h
      , VerificationKeyStore m (h p) s k
      , Cons s s Word8 Word8, AsEmpty s
      , Foldable t
      , ProtectionIndicator p
      )
  => a        
  -> k        
  -> JWS t p h  
  -> m s
verifyJWS = verifyJWSWithPayload pure
verifyJWSWithPayload
  ::  ( HasAlgorithms a, HasValidationPolicy a, AsError e, MonadError e m
      , HasJWSHeader h, HasParams h
      , VerificationKeyStore m (h p) payload k
      , Cons s s Word8 Word8, AsEmpty s
      , Foldable t
      , ProtectionIndicator p
      )
  => (s -> m payload)  
  -> a                 
  -> k                 
  -> JWS t p h         
  -> m payload
verifyJWSWithPayload dec conf k (JWS p@(Types.Base64Octets p') sigs) =
  let
    algs :: S.Set Alg
    algs = conf ^. algorithms
    policy :: ValidationPolicy
    policy = conf ^. validationPolicy
    shouldValidateSig = (`elem` algs) . view (header . alg . param)
    applyPolicy AnyValidated xs = unless (or xs) (throwing_ _JWSNoValidSignatures)
    applyPolicy AllValidated [] = throwing_ _JWSNoSignatures
    applyPolicy AllValidated xs = unless (and xs) (throwing_ _JWSInvalidSignature)
    validate payload sig = do
      keys <- getVerificationKeys (view header sig) payload k
      if null keys
        then throwing_ _NoUsableKeys
        else pure $ any ((== Right True) . verifySig p sig) keys
  in do
    payload <- (dec . view recons) p'
    results <- traverse (validate payload) $ filter shouldValidateSig $ toList sigs
    payload <$ applyPolicy policy results
verifySig
  :: (HasJWSHeader a, HasParams a, ProtectionIndicator p)
  => Types.Base64Octets
  -> Signature p a
  -> JWK
  -> Either Error Bool
verifySig msg sig@(Signature _ h (Types.Base64Octets s)) k =
  verify (view (alg . param) h) (view jwkMaterial k) tbs s
  where
  tbs = signingInput sig msg