{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.JOSE.Header
  (
  
    HeaderParam(..)
  , ProtectionIndicator(..)
  , Protection(..)
  , protection
  , isProtected
  , param
  
  
  , HasParams(..)
  , headerRequired
  , headerRequiredProtected
  , headerOptional
  , headerOptionalProtected
  
  , parseParams
  , parseCrit
  
  , protectedParamsEncoded
  , unprotectedParams
  
  , HasAlg(..)
  , HasJku(..)
  , HasJwk(..)
  , HasKid(..)
  , HasX5u(..)
  , HasX5c(..)
  , HasX5t(..)
  , HasX5tS256(..)
  , HasTyp(..)
  , HasCty(..)
  , HasCrit(..)
  ) where
import qualified Control.Monad.Fail as Fail
import Data.List.NonEmpty (NonEmpty)
import Data.Monoid ((<>))
import Data.Proxy (Proxy(..))
import Control.Lens (Lens', Getter, review, to)
import Data.Aeson (FromJSON(..), Object, Value, encode, object)
import Data.Aeson.Types (Pair, Parser)
import qualified Data.ByteString.Lazy as L
import qualified Data.HashMap.Strict as M
import qualified Data.Text as T
import qualified Crypto.JOSE.JWA.JWS as JWA.JWS
import Crypto.JOSE.JWK (JWK)
import Crypto.JOSE.Types.Orphans ()
import Crypto.JOSE.Types.Internal (base64url)
import qualified Crypto.JOSE.Types as Types
-- | Header types whose parameters can be split between a protected and
-- an unprotected header object.  The argument of @a@ is the protection
-- indicator type (see 'ProtectionIndicator').
class HasParams (a :: * -> *) where
  -- | Every header parameter as a JSON pair.  The 'Bool' is 'True' for
  -- parameters that go in the protected header and 'False' for those
  -- that go in the unprotected header.
  params :: ProtectionIndicator p => a p -> [(Bool, Pair)]

  -- | Names of the extension header parameters this header type
  -- understands.  The default is no extensions.
  -- NOTE(review): presumably consulted when validating @crit@; confirm.
  extensions :: Proxy a -> [T.Text]
  extensions _ = []

  -- | Parse a header from two optional header objects (presumably the
  -- protected object followed by the unprotected one, as in the
  -- header* helpers).  The 'Proxy' names an outer header type @b@;
  -- presumably so that its 'extensions' are available when @a@ is
  -- embedded in a larger header — TODO confirm.
  parseParamsFor
    :: (HasParams b, ProtectionIndicator p)
    => Proxy b
    -> Maybe Object
    -> Maybe Object
    -> Parser (a p)
-- | Parse a header of type @a p@ from the optional protected and
-- unprotected header objects.  This is 'parseParamsFor' with @a@ itself
-- used as the proxy type.
parseParams
  :: forall a p. (HasParams a, ProtectionIndicator p)
  => Maybe Object -- ^ protected header object, if any
  -> Maybe Object -- ^ unprotected header object, if any
  -> Parser (a p)
parseParams hp hu = parseParamsFor self hp hu
  where
    -- Scoped via the explicit @forall@ in the signature.
    self = Proxy :: Proxy a
-- | Collect the parameters flagged as protected by 'params' into a JSON
-- object.  Returns 'Nothing' when no parameter is protected.
protectedParams
  :: (HasParams a, ProtectionIndicator p)
  => a p
  -> Maybe Value -- ^ protected header object, if non-empty
protectedParams h
  | null prot = Nothing
  | otherwise = Just (object prot)
  where
    -- Keep original order; select only entries flagged 'True'.
    prot = [ pair | (True, pair) <- params h ]
-- | The protected header object, JSON-encoded and then encoded with the
-- 'base64url' prism.  Yields the empty string when there are no
-- protected parameters.
protectedParamsEncoded
  :: (HasParams a, ProtectionIndicator p)
  => a p
  -> L.ByteString
protectedParamsEncoded h = case protectedParams h of
  Nothing  -> mempty
  Just obj -> review base64url (encode obj)
-- | Collect the parameters flagged as unprotected by 'params' into a
-- JSON object.  Returns 'Nothing' when every parameter is protected.
unprotectedParams
  :: (HasParams a, ProtectionIndicator p)
  => a p
  -> Maybe Value -- ^ unprotected header object, if non-empty
unprotectedParams h
  | null unprot = Nothing
  | otherwise   = Just (object unprot)
  where
    -- Keep original order; select only entries flagged 'False'.
    unprot = [ pair | (False, pair) <- params h ]
-- | Protection indicator that can mark a header parameter as either
-- protected or unprotected.
data Protection
  = Protected    -- ^ parameter belongs in the protected header
  | Unprotected  -- ^ parameter belongs in the unprotected header
  deriving (Eq, Show)
-- | Types used to record whether a header parameter is protected.
-- Equality is required so a value can be compared against
-- 'getProtected'.
class (Eq a) => ProtectionIndicator a where
  -- | The value that marks a parameter as protected.
  getProtected :: a
  -- | The value that marks a parameter as unprotected, or 'Nothing'
  -- when this indicator type cannot represent unprotected parameters.
  getUnprotected :: Maybe a
-- | Both protected and unprotected parameters are representable.
instance ProtectionIndicator Protection where
  getUnprotected = pure Unprotected
  getProtected   = Protected
-- | Only protected parameters are representable.  There is no
-- unprotected marker, so parsers reject parameters that appear only in
-- an unprotected header.
instance ProtectionIndicator () where
  getUnprotected = Nothing
  getProtected   = ()
-- | A header parameter value of type @a@ paired with a protection
-- indicator of type @p@ (see 'ProtectionIndicator').
data HeaderParam p a
  = HeaderParam p a
  deriving (Eq, Show)
-- | Maps over the parameter value and leaves the protection indicator
-- untouched.
instance Functor (HeaderParam p) where
  fmap g (HeaderParam ind x) = HeaderParam ind $ g x
-- | Lens onto the protection indicator of a 'HeaderParam'.
protection :: Lens' (HeaderParam p a) p
protection f (HeaderParam ind v) = (`HeaderParam` v) <$> f ind
-- | Lens onto the value carried by a 'HeaderParam'.
param :: Lens' (HeaderParam p a) a
param f (HeaderParam ind v) = HeaderParam ind <$> f v
-- | 'True' when the parameter's indicator equals 'getProtected'.
isProtected :: (ProtectionIndicator p) => Getter (HeaderParam p a) Bool
isProtected = to (\(HeaderParam ind _) -> ind == getProtected)
-- | Parse an optional header parameter that may appear in either the
-- protected or the unprotected header object.
--
-- Returns 'Nothing' when the parameter is in neither object.  Fails if
-- it appears in both objects.  It also fails if the parameter appears
-- only in the unprotected object and the indicator type @p@ cannot
-- represent unprotected parameters ('getUnprotected' is 'Nothing').
headerOptional
  :: (FromJSON a, ProtectionIndicator p)
  => T.Text          -- ^ header parameter name
  -> Maybe Object    -- ^ protected header object, if any
  -> Maybe Object    -- ^ unprotected header object, if any
  -> Parser (Maybe (HeaderParam p a))
headerOptional k hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
  (Just _, Just _)    -> fail $ "duplicate header " ++ show k
  (Just v, Nothing)   -> Just . HeaderParam getProtected <$> parseJSON v
  (Nothing, Just v)   -> maybe
    -- Name the offending parameter, like the other failures here do.
    (fail $ "unprotected header not supported: " ++ show k)
    (\p -> Just . HeaderParam p <$> parseJSON v)
    getUnprotected
  (Nothing, Nothing)  -> pure Nothing
-- | Parse an optional header parameter that may only appear in the
-- protected header object.
--
-- Returns 'Nothing' when the parameter is absent.  Fails whenever the
-- parameter is present in the unprotected object: with a duplicate
-- error if it is also in the protected object, otherwise with a
-- must-be-protected error.
headerOptionalProtected
  :: FromJSON a
  => T.Text          -- ^ header parameter name
  -> Maybe Object    -- ^ protected header object, if any
  -> Maybe Object    -- ^ unprotected header object, if any
  -> Parser (Maybe a)
headerOptionalProtected k hp hu =
  case (hp >>= M.lookup k, hu >>= M.lookup k) of
    (Just _, Just _)  -> fail ("duplicate header " ++ show k)
    (Nothing, Just _) -> fail ("header must be protected: " ++ show k)
    -- Absent from the unprotected object: parse the protected value, if any.
    (inProt, Nothing) -> traverse parseJSON inProt
-- | Parse a required header parameter that may appear in either the
-- protected or the unprotected header object.
--
-- Fails if the parameter is in neither object or in both.  It also
-- fails if the parameter appears only in the unprotected object and
-- the indicator type @p@ cannot represent unprotected parameters
-- ('getUnprotected' is 'Nothing').
headerRequired
  :: (FromJSON a, ProtectionIndicator p)
  => T.Text          -- ^ header parameter name
  -> Maybe Object    -- ^ protected header object, if any
  -> Maybe Object    -- ^ unprotected header object, if any
  -> Parser (HeaderParam p a)
headerRequired k hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
  (Just _, Just _)    -> fail $ "duplicate header " ++ show k
  (Just v, Nothing)   -> HeaderParam getProtected <$> parseJSON v
  (Nothing, Just v)   -> maybe
    -- Name the offending parameter, like the other failures here do.
    (fail $ "unprotected header not supported: " ++ show k)
    (\p -> HeaderParam p <$> parseJSON v)
    getUnprotected
  (Nothing, Nothing)  -> fail $ "missing required header " ++ show k
-- | Parse a required header parameter that may only appear in the
-- protected header object.
--
-- Fails if the parameter is absent from the protected object.  It also
-- fails whenever the parameter is present in the unprotected object:
-- with a duplicate error if it is in both objects, otherwise with a
-- must-be-protected error.
headerRequiredProtected
  :: FromJSON a
  => T.Text          -- ^ header parameter name
  -> Maybe Object    -- ^ protected header object, if any
  -> Maybe Object    -- ^ unprotected header object, if any
  -> Parser a
headerRequiredProtected k hp hu =
  case hu >>= M.lookup k of
    Just _
      | Just _ <- inProtected -> fail ("duplicate header " ++ show k)
      | otherwise             -> fail ("header must be protected: " ++ show k)
    Nothing -> maybe
      (fail ("missing required protected header: " ++ show k))
      parseJSON
      inProtected
  where
    inProtected = hp >>= M.lookup k
-- | Validate a single entry @s@ of a @crit@ header value.
--
-- Fails if @s@ is one of the reserved names, if it is not one of the
-- understood extension names, or if it is not a key of the header
-- object @o@.  Otherwise it returns @s@ unchanged.  Each failure
-- message names the offending entry, consistent with the header*
-- parsers in this module.
critObjectParser
  :: (Foldable t0, Foldable t1, Fail.MonadFail m)
  => t0 T.Text   -- ^ reserved header parameter names
  -> t1 T.Text   -- ^ understood extension header parameter names
  -> Object      -- ^ header object each crit entry must be a key of
  -> T.Text      -- ^ the crit entry being checked
  -> m T.Text
critObjectParser reserved exts o s
  | s `elem` reserved    =
      Fail.fail $ "crit key is reserved: " ++ show s
  | s `notElem` exts     =
      Fail.fail $ "crit key is not understood: " ++ show s
  | not (s `M.member` o) =
      Fail.fail $ "crit key is not present in headers: " ++ show s
  | otherwise            = pure s
-- | Validate every entry of a @crit@ value with 'critObjectParser'.
-- The outer container allows an optional value, for example the
-- @Maybe (NonEmpty Text)@ used by 'crit'.  Fails if any entry is
-- invalid; otherwise it returns the input unchanged.
parseCrit
  :: (Foldable t0, Foldable t1, Traversable t2, Traversable t3, Fail.MonadFail m)
  => t0 T.Text        -- ^ reserved header parameter names
  -> t1 T.Text        -- ^ understood extension header parameter names
  -> Object           -- ^ header object the crit entries must be keys of
  -> t2 (t3 T.Text)   -- ^ crit value to validate
  -> m (t2 (t3 T.Text))
parseCrit reserved exts o = traverse (traverse checkEntry)
  where
    checkEntry = critObjectParser reserved exts o
  
-- | Header types with an @alg@ parameter.  The lens target is not
-- wrapped in 'Maybe', so the parameter is always present.
class HasAlg a where
  alg :: Lens' (a p) (HeaderParam p JWA.JWS.Alg)
-- | Header types with an optional @jku@ parameter holding a URI.
class HasJku a where
  jku :: Lens' (a p) (Maybe (HeaderParam p Types.URI))
-- | Header types with an optional @jwk@ parameter holding a 'JWK'.
class HasJwk a where
  jwk :: Lens' (a p) (Maybe (HeaderParam p JWK))
-- | Header types with an optional textual @kid@ parameter.
class HasKid a where
  kid :: Lens' (a p) (Maybe (HeaderParam p T.Text))
-- | Header types with an optional @x5u@ parameter holding a URI.
class HasX5u a where
  x5u :: Lens' (a p) (Maybe (HeaderParam p Types.URI))
-- | Header types with an optional @x5c@ parameter: a non-empty list of
-- certificates.
class HasX5c a where
  x5c :: Lens' (a p) (Maybe (HeaderParam p (NonEmpty Types.SignedCertificate)))
-- | Header types with an optional @x5t@ parameter holding a
-- 'Types.Base64SHA1' value.
class HasX5t a where
  x5t :: Lens' (a p) (Maybe (HeaderParam p Types.Base64SHA1))
-- | Header types with an optional @x5t#S256@ parameter holding a
-- 'Types.Base64SHA256' value.
class HasX5tS256 a where
  x5tS256 :: Lens' (a p) (Maybe (HeaderParam p Types.Base64SHA256))
-- | Header types with an optional textual @typ@ parameter.
class HasTyp a where
  typ :: Lens' (a p) (Maybe (HeaderParam p T.Text))
-- | Header types with an optional textual @cty@ parameter.
class HasCty a where
  cty :: Lens' (a p) (Maybe (HeaderParam p T.Text))
-- | Header types with an optional @crit@ parameter: a non-empty list of
-- header parameter names.  Unlike the classes above, the value is not
-- wrapped in 'HeaderParam', so no protection indicator is stored.
-- NOTE(review): presumably because @crit@ must always be protected and
-- is parsed with 'headerOptionalProtected'; confirm.
class HasCrit a where
  crit :: Lens' (a p) (Maybe (NonEmpty T.Text))