-- |
-- Module      : Crypto.PubKey.Rabin.OAEP
-- License     : BSD-style
-- Maintainer  : Carlos Rodriguez-Vega <crodveg@yahoo.es>
-- Stability   : experimental
-- Portability : unknown
--
-- OAEP padding scheme.
-- See <http://en.wikipedia.org/wiki/Optimal_asymmetric_encryption_padding>.
--
module Crypto.PubKey.Rabin.OAEP
    ( OAEPParams(..)
    , defaultOAEPParams
    , pad
    , unpad
    ) where
        
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import           Data.Bits (xor)

import           Crypto.Hash
import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray)
import qualified Crypto.Internal.ByteArray as B (convert)
import           Crypto.PubKey.MaskGenFunction
import           Crypto.PubKey.Internal (and')
import           Crypto.PubKey.Rabin.Types

-- | Configuration shared by OAEP encoding ('pad') and decoding ('unpad'):
-- which hash to use, which mask generation function to use, and an optional
-- label.
data OAEPParams hash seed output = OAEPParams
    { -- | Hash function; it digests the label and its digest size fixes the
      -- required seed length.
      oaepHash :: hash
      -- | Mask generation function applied to the seed and to the data block.
    , oaepMaskGenAlg :: MaskGenAlgorithm seed output
      -- | Optional label. Its digest (the digest of the empty string when
      -- 'Nothing') is placed at the start of the padded data block.
    , oaepLabel :: Maybe ByteString
    }

-- | Standard OAEP parameters for a given hash: the hash is used both as the
-- OAEP hash and as the underlying hash of MGF1, and no label is set.
defaultOAEPParams :: (ByteArrayAccess seed, ByteArray output, HashAlgorithm hash)
                  => hash
                  -> OAEPParams hash seed output
defaultOAEPParams hashAlg =
    OAEPParams { oaepLabel      = Nothing
               , oaepMaskGenAlg = mgf1 hashAlg
               , oaepHash       = hashAlg
               }

-- | Encode a message with OAEP padding.
--
-- The output is @0x00 || maskedSeed || maskedDB@ and is exactly @k@ bytes
-- long. Here @DB = lHash || PS || 0x01 || msg@, @PS@ is a run of zero bytes
-- and @lHash@ is the digest of the label.
--
-- Returns 'InvalidParameters' when @k < 2 * hashLen + 2@ or when the seed is
-- not exactly one digest long. Returns 'MessageTooLong' when the message
-- does not fit in the remaining space.
pad :: HashAlgorithm hash
    => ByteString                               -- ^ Seed
    -> OAEPParams hash ByteString ByteString    -- ^ OAEP params to use
    -> Int                                      -- ^ size of public key in bytes
    -> ByteString                               -- ^ Message pad
    -> Either Error ByteString
pad seed oaep k msg
    | k < minLen                = Left InvalidParameters
    | B.length seed /= hLen     = Left InvalidParameters
    | B.length msg > k - minLen = Left MessageTooLong
    | otherwise                 = Right encoded
  where
    hashAlg = oaepHash oaep
    mask    = oaepMaskGenAlg oaep
    hLen    = hashDigestSize hashAlg
    -- fixed overhead: leading 0x00, masked seed, lHash and the 0x01 separator
    minLen  = 2 * hLen + 2
    label   = maybe B.empty id (oaepLabel oaep)
    lHash   = B.convert (hashWith hashAlg label) :: ByteString
    zeros   = B.replicate (k - B.length msg - minLen) 0
    block   = B.concat [lHash, zeros, B.singleton 0x1, msg]
    -- byte-wise XOR of two strings (truncated to the shorter one)
    xorBytes a b = B.pack (B.zipWith xor a b)
    maskedBlock  = xorBytes block (mask seed (k - hLen - 1))
    maskedSeed   = xorBytes seed (mask maskedBlock hLen)
    encoded      = B.concat [B.singleton 0x0, maskedSeed, maskedBlock]

-- | Un-pad an OAEP encoded message.
--
-- Expects @0x00 || maskedSeed || maskedDB@, the layout produced by 'pad'.
-- The input must be exactly @k@ bytes long, and @k@ must be at least
-- @2 * hashLen + 2@ (the same bound 'pad' enforces). If either condition
-- fails, or any structural check on the decoded block fails, the result is
-- 'MessageNotRecognized'.
unpad :: HashAlgorithm hash
      => OAEPParams hash ByteString ByteString  -- ^ OAEP params to use
      -> Int                                    -- ^ size of public key in bytes
      -> ByteString                             -- ^ encoded message (not encrypted)
      -> Either Error ByteString
unpad oaep k em
    | paddingSuccess = Right msg
    | otherwise      = Left MessageNotRecognized
    where -- parameters
        mgf        = oaepMaskGenAlg oaep
        labelHash  = B.convert $ hashWith (oaepHash oaep) (maybe B.empty id $ oaepLabel oaep)
        hashLen    = hashDigestSize (oaepHash oaep)
        -- getting em's fields: 0x00 || maskedSeed || maskedDB
        (pb, em0)  = B.splitAt 1 em
        (maskedSeed, maskedDB) = B.splitAt hashLen em0
        seedMask   = mgf maskedDB hashLen
        seed       = B.pack $ B.zipWith xor maskedSeed seedMask
        dbmask     = mgf seed (k - hashLen - 1)
        db         = B.pack $ B.zipWith xor maskedDB dbmask
        -- getting db's fields: lHash' || zero run || 0x01 || msg
        (labelHash', db1) = B.splitAt hashLen db
        (_, db2)   = B.break (/= 0) db1
        (ps1, msg) = B.splitAt 1 db2

        -- The length check matters. Without it, an input longer or shorter
        -- than k is silently truncated by B.zipWith (maskedDB vs. dbmask),
        -- so a message could be returned with bytes missing.
        -- NOTE(review): and' (Crypto.PubKey.Internal) is presumably a
        -- non-short-circuiting conjunction; confirm.
        paddingSuccess = and' [ B.length em == k
                              , k >= 2 * hashLen + 2
                              , labelHash' == labelHash -- no need for constant eq
                              , ps1        == B.singleton 0x1
                              , pb         == B.singleton 0x0
                              ]