-----------------------------------------------------------------------------
-- |
-- Module      :  Codec.Encryption.RSA.EMEOAEP
-- Copyright   :  (c) David J. Sankel 2003, Dominic Steinitz 2003
-- License     :  GPL (see the file ReadMe.tex)
--
-- Stability   :  experimental
-- Portability :  non-portable
--
-- A modified version of the EMEOAEP module supplied by David J. Sankel
-- (<http://www.electronconsulting.com/rsa-haskell>).
--
-- As the original code is GPL, this code must be GPL as well.
-- This code is free software; you can redistribute it and\/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation; either version 2 of the License, or
-- (at your option) any later version.
--
-- This code is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this code; if not, write to the Free Software
-- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111\-1307  USA
-----------------------------------------------------------------------------

module Codec.Encryption.RSA.EMEOAEP(
   -- * Function Types
   encode,
   decode
   )where

import Codec.Utils (Octet)
import Data.Bits

-- | Combine two octet strings element by element with exclusive-or.
--   The result is as long as the shorter argument; any surplus elements
--   of the longer argument are dropped.
xorOctets :: Bits a => [a] -> [a] -> [a]
xorOctets xs ys = [x `xor` y | (x, y) <- zip xs ys]

-- | EME-OAEP encoding.
--
--   Takes a mask generating function, a hash function, a label (which may
--   be empty), a random seed, the modulus of the key and the message, and
--   returns the encoded message @0x00 || maskedSeed || maskedDB@.
--   Assuming @mgf@ returns exactly the number of octets requested, the
--   result is @length n@ octets long. NB you could pass in the length of
--   the modulus, but it seems safer to pass in the modulus itself and
--   calculate the length when required. See
--   <ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1.pdf> for more
--   details.
--
--   Calls 'error' if the message is too long for the modulus, or if the
--   seed is not exactly as long as the hash function's output.
encode
  :: (([Octet] -> [Octet]) -> [Octet] -> Int -> [Octet])
     -- ^ mask generating function, used as @mgf hash seed len@
  -> ([Octet] -> [Octet])  -- ^ hash function
  -> [Octet]               -- ^ label (may be empty)
  -> [Octet]               -- ^ random seed; must be @hLen@ octets long
  -> [Octet]               -- ^ modulus of the key
  -> [Octet]               -- ^ message
  -> [Octet]
encode mgf hash p seed n m
  | mLen > emLen - 2 * hLen - 2 =
      error "Codec.Encryption.EMEOAEP.encode: message too long"
  -- The decoder recovers exactly hLen seed octets (seedMask is hLen long),
  -- so any other seed length would yield an undecodable encoding.
  | length seed /= hLen =
      error "Codec.Encryption.EMEOAEP.encode: seed length must equal hash length"
  | otherwise = em
  where
    emLen = length n
    mLen  = length m
    pHash = hash p
    hLen  = length pHash
    -- Zero padding that brings DB up to emLen - hLen - 1 octets.
    ps :: [Octet]
    ps = replicate (emLen - mLen - 2 * hLen - 2) 0x00
    -- DB = pHash || PS || 0x01 || M
    db :: [Octet]
    db = pHash ++ ps ++ [0x01] ++ m
    dbMask     = mgf hash seed (length db)
    maskedDB   = db `xorOctets` dbMask
    seedMask   = mgf hash maskedDB hLen
    maskedSeed = seed `xorOctets` seedMask
    em :: [Octet]
    em = 0x00 : (maskedSeed ++ maskedDB)

-- | EME-OAEP decoding; the inverse of 'encode'.
--
--   Takes a mask generating function, a hash function, a label (which may
--   be empty) and the encoded message, and returns the decoded message.
--
--   Every malformed input is reported with the same 'error' message
--   (\"decryption error\"): too short, a wrong leading octet, a label-hash
--   mismatch, or a missing 0x01 separator. Callers therefore cannot tell
--   the failure conditions apart from the message.
--   NOTE(review): the checks short-circuit, so timing may still differ
--   between failure cases.
decode
  :: (([Octet] -> [Octet]) -> [Octet] -> Int -> [Octet])
     -- ^ mask generating function, used as @mgf hash seed len@
  -> ([Octet] -> [Octet])  -- ^ hash function
  -> [Octet]               -- ^ label (may be empty)
  -> [Octet]               -- ^ encoded message
  -> [Octet]
decode mgf hash p em
  -- One leading octet, hLen octets of maskedSeed and at least hLen + 1
  -- octets of maskedDB (label hash plus the 0x01 separator) are required.
  | emLen < 2 * hLen + 2 = decryptionError
  | otherwise =
      -- A total match: an all-zero (or empty) remainder is a decryption
      -- error, not a pattern-match failure.
      case dropWhile (== 0x00) rest' of
        (one : m)
          | one == 0x01 && pHash' == pHash && y == [0x00] -> m
        _ -> decryptionError
  where
    decryptionError :: [Octet]
    decryptionError =
      error "Codec.Encryption.EMEOAEP.decode: decryption error"
    emLen = length em
    pHash = hash p
    hLen  = length pHash
    -- EM = Y || maskedSeed || maskedDB
    (y, rest)              = splitAt 1 em
    (maskedSeed, maskedDB) = splitAt hLen rest
    -- Undo the two masking steps performed by 'encode'.
    seedMask = mgf hash maskedDB hLen
    seed     = maskedSeed `xorOctets` seedMask
    dbMask   = mgf hash seed (emLen - hLen - 1)
    db       = maskedDB `xorOctets` dbMask
    -- DB = pHash' || PS || 0x01 || M
    (pHash', rest') = splitAt hLen db