module Crypto.PubKey.Rabin.OAEP
( OAEPParams(..)
, defaultOAEPParams
, pad
, unpad
) where
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Bits (xor)
import Crypto.Hash
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray)
import qualified Crypto.Internal.ByteArray as B (convert)
import Crypto.PubKey.MaskGenFunction
import Crypto.PubKey.Internal (and')
import Crypto.PubKey.Rabin.Types
-- | Parameters controlling OAEP encoding and decoding.
data OAEPParams hash seed output = OAEPParams
    { oaepHash       :: hash
      -- ^ Hash algorithm; it hashes the label and its digest size fixes the
      -- width of the seed and label-hash fields.
    , oaepMaskGenAlg :: MaskGenAlgorithm seed output
      -- ^ Mask generation function, applied to a seed and a requested
      -- output length.
    , oaepLabel      :: Maybe ByteString
      -- ^ Optional label bound to the message; 'Nothing' is treated as the
      -- empty label.
    }
-- | Build OAEP parameters from a hash algorithm alone: the mask generation
-- function is 'mgf1' instantiated with that same algorithm, and no label is
-- set.
defaultOAEPParams
    :: (ByteArrayAccess seed, ByteArray output, HashAlgorithm hash)
    => hash                         -- ^ hash algorithm used for both hashing and 'mgf1'
    -> OAEPParams hash seed output
defaultOAEPParams hashAlg = OAEPParams hashAlg (mgf1 hashAlg) Nothing
-- | Encode a message with OAEP padding.
--
-- The result is laid out as @0x00 || maskedSeed || maskedDB@. The data block
-- @labelHash || 0x00 .. 0x00 || 0x01 || message@ is XORed with the mask
-- generated from the seed, and the seed is XORed with the mask generated from
-- the already-masked data block.
--
-- Fails with 'InvalidParameters' when @k@ cannot hold two digests plus two
-- bytes, or when the seed is not exactly one digest long; fails with
-- 'MessageTooLong' when the message does not fit in the space that remains.
pad :: HashAlgorithm hash
    => ByteString                             -- ^ seed; must be exactly one digest long
    -> OAEPParams hash ByteString ByteString  -- ^ OAEP parameters
    -> Int                                    -- ^ target length of the encoded message, in bytes
    -> ByteString                             -- ^ message to encode
    -> Either Error ByteString
pad seed oaep k msg
    | k < overhead               = Left InvalidParameters
    | B.length seed /= digestLen = Left InvalidParameters
    | msgLen > k - overhead      = Left MessageTooLong
    | otherwise                  = Right encoded
  where
    hashAlg   = oaepHash oaep
    maskGen   = oaepMaskGenAlg oaep
    digestLen = hashDigestSize hashAlg
    msgLen    = B.length msg
    -- Bytes not available to the message: two digests (seed and label hash),
    -- the leading 0x00 and the 0x01 separator.
    overhead  = 2 * digestLen + 2

    -- A missing label is hashed as the empty string.
    label = case oaepLabel oaep of
        Just l  -> l
        Nothing -> B.empty
    labelDigest = hashWith hashAlg label

    -- Byte-wise XOR; the result is as long as the shorter argument.
    xorBytes a b = B.pack (B.zipWith xor a b)

    zeroPad    = B.replicate (k - msgLen - overhead) 0
    dataBlock  = B.concat [B.convert labelDigest, zeroPad, B.singleton 0x1, msg]
    maskedData = xorBytes dataBlock (maskGen seed (k - digestLen - 1))
    maskedSeed = xorBytes seed (maskGen maskedData digestLen)
    encoded    = B.concat [B.singleton 0x0, maskedSeed, maskedData]
-- | Remove OAEP padding from an encoded message, undoing 'pad'.
--
-- The encoded message is split into a leading byte, a masked seed one digest
-- long, and a masked data block. The seed and then the data block are
-- unmasked with the mask generation function, and the data block is parsed as
-- @labelHash || 0x00 .. 0x00 || 0x01 || message@.
--
-- Fails with 'MessageNotRecognized' when the encoded message is not exactly
-- @k@ bytes long, when the leading byte is not @0x00@, when the recovered
-- label hash does not match the configured label, or when the zero padding is
-- not followed by a @0x01@ separator.
unpad :: HashAlgorithm hash
      => OAEPParams hash ByteString ByteString  -- ^ OAEP parameters
      -> Int                                    -- ^ expected length of the encoded message, in bytes
      -> ByteString                             -- ^ encoded message
      -> Either Error ByteString
unpad oaep k em
    | paddingSuccess = Right msg
    | otherwise      = Left MessageNotRecognized
  where
    -- parameters
    mgf        = oaepMaskGenAlg oaep
    labelHash  = B.convert $ hashWith (oaepHash oaep) (maybe B.empty id $ oaepLabel oaep)
    hashLen    = hashDigestSize (oaepHash oaep)

    -- split the encoded message into its fields and unmask them
    (pb, em0)  = B.splitAt 1 em
    (maskedSeed, maskedDB) = B.splitAt hashLen em0
    seedMask   = mgf maskedDB hashLen
    seed       = B.pack $ B.zipWith xor maskedSeed seedMask
    dbmask     = mgf seed (k - hashLen - 1)
    db         = B.pack $ B.zipWith xor maskedDB dbmask

    -- split the data block: label hash, zero padding, 0x01 separator, message
    (labelHash', db1) = B.splitAt hashLen db
    (_, db2)   = B.break (/= 0) db1
    (ps1, msg) = B.splitAt 1 db2

    -- The length check is required for correctness: 'B.zipWith' stops at the
    -- shorter operand, so an encoded message longer than @k@ would otherwise
    -- be XORed against a shorter mask and its tail silently dropped, which
    -- can yield a truncated message instead of an error.
    paddingSuccess = and' [ B.length em == k
                          , labelHash'  == labelHash
                          , ps1         == B.singleton 0x1
                          , pb          == B.singleton 0x0
                          ]