-- Copyright (C) 2014-2022  Fraser Tweedale
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}

{-|

JOSE error types and helpers.

-}
module Crypto.JOSE.Error
  (
  -- * Running JOSE computations
    runJOSE
  , unwrapJOSE
  , JOSE(..)

  -- * Base error type and class
  , Error(..)
  , AsError(..)

  -- * JOSE compact serialisation errors
  , InvalidNumberOfParts(..), expectedParts, actualParts
  , CompactTextError(..)
  , CompactDecodeError(..)
  , _CompactInvalidNumberOfParts
  , _CompactInvalidText

  ) where

import Numeric.Natural

import Control.Monad.Except (MonadError(..), ExceptT, runExceptT)
import Control.Monad.Trans (MonadIO(liftIO), MonadTrans(lift))
import qualified Crypto.PubKey.RSA as RSA
import Crypto.Error (CryptoError)
import Crypto.Random (MonadRandom(..))
import Control.Lens (Getter, to)
import Control.Lens.TH (makeClassyPrisms, makePrisms)
import qualified Data.Text as T
import qualified Data.Text.Encoding.Error as T


-- | Raised when decoding a compact JOSE object finds a different number
-- of parts than was expected.
--
data InvalidNumberOfParts
  = InvalidNumberOfParts Natural Natural -- ^ expected vs actual parts
  deriving Eq

-- | Human-readable rendering, e.g. @Expected 3 parts; got 2@.
instance Show InvalidNumberOfParts where
  show (InvalidNumberOfParts expected actual) =
    concat ["Expected ", show expected, " parts; got ", show actual]

-- | Read one of the two counts carried by an 'InvalidNumberOfParts':
-- 'expectedParts' yields the first (expected) count and 'actualParts'
-- yields the second (actual) count.
expectedParts, actualParts :: Getter InvalidNumberOfParts Natural
expectedParts = to expectedOf
  where
    expectedOf (InvalidNumberOfParts wanted _) = wanted
actualParts = to actualOf
  where
    actualOf (InvalidNumberOfParts _ found) = found


-- | A part of a compact object that is not valid UTF-8.  Carries the
-- index of the offending part together with the decoder's
-- 'T.UnicodeException'.
data CompactTextError
  = CompactTextError Natural T.UnicodeException
  deriving Eq

-- | Human-readable rendering:
-- @Invalid text at part \<index\>: \<exception\>@.
instance Show CompactTextError where
  show (CompactTextError partIndex problem) =
    concat ["Invalid text at part ", show partIndex, ": ", show problem]


-- | An error when decoding a JOSE compact object.
-- JSON decoding errors that occur during compact object processing
-- throw 'JSONDecodeError' instead.
--
data CompactDecodeError
  = CompactInvalidNumberOfParts InvalidNumberOfParts
  -- ^ The object did not have the expected number of parts
  | CompactInvalidText CompactTextError
  -- ^ A part of the object was not valid UTF-8 text
  deriving Eq
makePrisms ''CompactDecodeError

-- | Prefixes the rendering of the wrapped error with a short label
-- naming which kind of decode failure occurred.
instance Show CompactDecodeError where
  show err = case err of
    CompactInvalidNumberOfParts inner -> "Invalid number of parts: " <> show inner
    CompactInvalidText inner -> "Invalid text: " <> show inner



-- | All the errors that can occur.
--
data Error
  = AlgorithmNotImplemented   -- ^ A requested algorithm is not implemented
  | AlgorithmMismatch String  -- ^ A requested algorithm cannot be used
  | KeyMismatch T.Text        -- ^ Wrong type of key was given
  | KeySizeTooSmall           -- ^ Key size is too small
  | OtherPrimesNotSupported   -- ^ RSA private key with >2 primes not supported
  | RSAError RSA.Error        -- ^ RSA encryption, decryption or signing error
  | CryptoError CryptoError   -- ^ Various cryptonite library error cases
  | CompactDecodeError CompactDecodeError
  -- ^ Error decoding a compact serialisation: either the wrong number of
  --   parts, or a part that is not valid UTF-8 text
  --   (see 'CompactDecodeError')
  | JSONDecodeError String    -- ^ JSON (Aeson) decoding error
  | NoUsableKeys              -- ^ No usable keys were found in the key store
  | JWSCritUnprotected
  -- NOTE(review): the name suggests a @crit@ header parameter was found
  -- outside the protected header; confirm against the code that throws it.
  | JWSNoValidSignatures
  -- ^ 'AnyValidated' policy active, and no valid signature encountered
  | JWSInvalidSignature
  -- ^ 'AllValidated' policy active, and invalid signature encountered
  | JWSNoSignatures
  -- ^ 'AllValidated' policy active, and there were no signatures on object
  --   that matched the allowed algorithms
  deriving (Eq, Show)
makeClassyPrisms ''Error


-- | A computation over the base monad @m@ that may fail with an error of
-- type @e@.  Internally this is just @'ExceptT' e m a@.
newtype JOSE e m a = JOSE (ExceptT e m a)

-- | Execute a 'JOSE' computation in the base monad, producing
-- @Left err@ on failure or @Right result@ on success.  The error type
-- is typically 'Error' or 'Crypto.JWT.JWTError'.
runJOSE :: JOSE e m a -> m (Either e a)
runJOSE = runExceptT . unwrapJOSE

-- | Expose the 'ExceptT' wrapped by a 'JOSE' computation.
-- Prefer 'runJOSE' unless an 'ExceptT' value is specifically required.
unwrapJOSE :: JOSE e m a -> ExceptT e m a
unwrapJOSE (JOSE inner) = inner


-- | Maps over the successful result, delegating to the wrapped 'ExceptT'.
instance (Functor m) => Functor (JOSE e m) where
  fmap f = JOSE . fmap f . unwrapJOSE

-- | Delegates 'pure' and '<*>' to the wrapped 'ExceptT'.
instance (Monad m) => Applicative (JOSE e m) where
  pure x = JOSE (pure x)
  fs <*> xs = JOSE (unwrapJOSE fs <*> unwrapJOSE xs)

-- | Sequencing is performed in the wrapped 'ExceptT', so a failure in the
-- first computation skips the continuation.
instance (Monad m) => Monad (JOSE e m) where
  action >>= next = JOSE $ do
    result <- unwrapJOSE action
    unwrapJOSE (next result)

-- | Lifts a base-monad action by lifting it into the wrapped 'ExceptT'.
instance MonadTrans (JOSE e) where
  lift base = JOSE (lift base)

-- | Throwing and catching use the wrapped 'ExceptT''s error channel.
instance (Monad m) => MonadError e (JOSE e m) where
  throwError err = JOSE (throwError err)
  catchError action handler =
    JOSE (unwrapJOSE action `catchError` (unwrapJOSE . handler))

-- | Lifts an 'IO' action via the wrapped 'ExceptT''s 'MonadIO' instance.
instance (MonadIO m) => MonadIO (JOSE e m) where
  liftIO io = JOSE (liftIO io)

-- | Random bytes are obtained from the base monad's 'MonadRandom'
-- instance and lifted into 'JOSE'.
instance (MonadRandom m) => MonadRandom (JOSE e m) where
  getRandomBytes count = lift (getRandomBytes count)