{-# LANGUAGE DeriveDataTypeable #-}
module Crypto.PubKey.Rabin.Basic
( PublicKey(..)
, PrivateKey(..)
, Signature(..)
, generate
, encrypt
, encryptWithSeed
, decrypt
, sign
, signWith
, verify
) where
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Data
import Data.Either (rights)
import Crypto.Hash
import Crypto.Number.Basic (gcde, numBytes)
import Crypto.Number.ModArithmetic (expSafe, jacobi)
import Crypto.Number.Serialize (i2osp, i2ospOf_, os2ip)
import Crypto.PubKey.Rabin.OAEP
import Crypto.PubKey.Rabin.Types
import Crypto.Random (MonadRandom, getRandomBytes)
-- | A Rabin public key.
data PublicKey = PublicKey
    { public_size :: Int      -- ^ size argument the key was generated with
                              --   (unit defined by 'generatePrimes' — presumably bytes, TODO confirm)
    , public_n    :: Integer  -- ^ public modulus @n = p * q@
    } deriving (Data, Eq, Read, Show)
-- | A Rabin private key: the two prime factors of the modulus together with
-- the Bézout coefficients used to recombine square roots via the CRT.
data PrivateKey = PrivateKey
    { private_pub :: PublicKey  -- ^ the matching public key
    , private_p   :: Integer    -- ^ first prime factor of the modulus
    , private_q   :: Integer    -- ^ second prime factor of the modulus
    , private_a   :: Integer    -- ^ coefficient of @p@ returned by @gcde p q@
    , private_b   :: Integer    -- ^ coefficient of @q@ returned by @gcde p q@
    } deriving (Data, Eq, Read, Show)
-- | A Rabin signature: the padding (encoded as an Integer with 'os2ip')
-- paired with a square root, modulo n, of the padded message digest.
data Signature = Signature (Integer, Integer)
    deriving (Data, Eq, Read, Show)
-- | Generate a Rabin key pair.
--
-- Both primes are required to be congruent to 3 modulo 4. That is what lets
-- square roots be taken with a single modular exponentiation. The size
-- argument is forwarded to 'generatePrimes' and recorded as 'public_size'.
generate :: MonadRandom m
         => Int
         -> m (PublicKey, PrivateKey)
generate size = fmap toKeyPair (generatePrimes size isThreeModFour isThreeModFour)
  where
    isThreeModFour x = x `mod` 4 == 3

    toKeyPair (p, q) =
        let (a, b, _) = gcde p q
            pub       = PublicKey { public_size = size, public_n = p * q }
            priv      = PrivateKey { private_pub = pub
                                   , private_p   = p
                                   , private_q   = q
                                   , private_a   = a
                                   , private_b   = b }
        in (pub, priv)
-- | Deterministic Rabin encryption with OAEP-style padding.
--
-- The message is padded to @numBytes n@ bytes using the supplied seed, and
-- the padded value is squared modulo n. Any padding error is returned
-- unchanged.
encryptWithSeed :: HashAlgorithm hash
                => ByteString
                -> OAEPParams hash ByteString ByteString
                -> PublicKey
                -> ByteString
                -> Either Error ByteString
encryptWithSeed seed oaep pk m = fmap square (pad seed oaep k m)
  where
    n = public_n pk
    k = numBytes n
    -- ciphertext = (padded message)^2 mod n, serialized big-endian
    square em = i2osp (expSafe (os2ip em) 2 n)
-- | Randomized Rabin-OAEP encryption.
--
-- Draws a random seed as long as the digest of the OAEP hash, then
-- delegates to 'encryptWithSeed'.
encrypt :: (HashAlgorithm hash, MonadRandom m)
        => OAEPParams hash ByteString ByteString
        -> PublicKey
        -> ByteString
        -> m (Either Error ByteString)
encrypt oaep pk m =
    fmap (\seed -> encryptWithSeed seed oaep pk m) (getRandomBytes seedLen)
  where
    seedLen = hashDigestSize (oaepHash oaep)
-- | Rabin-OAEP decryption.
--
-- A Rabin ciphertext has four square roots modulo n. Each root is serialized
-- to @numBytes n@ bytes and OAEP-unpadded. Decryption succeeds only when
-- exactly one root unpads successfully. Ciphertexts whose integer value is
-- not below n are rejected, because 'encrypt' only produces values reduced
-- modulo n. Accepting them would make @c@ and @c + n@ decrypt identically.
decrypt :: HashAlgorithm hash
        => OAEPParams hash ByteString ByteString
        -> PrivateKey
        -> ByteString
        -> Maybe ByteString
decrypt oaep pk c
    | c' >= n   = Nothing
    | otherwise =
        -- total pattern match instead of a length check followed by 'head'
        case plaintexts of
            [plaintext] -> Just plaintext
            _           -> Nothing
  where
    p  = private_p pk
    q  = private_q pk
    a  = private_a pk
    b  = private_b pk
    n  = public_n $ private_pub pk
    k  = numBytes n
    c' = os2ip c
    (r1, r2, r3, r4) = sqroot' c' p q a b n
    -- keep only the roots that carry valid OAEP padding
    plaintexts = rights [ unpad oaep k (i2ospOf_ k r) | r <- [r1, r2, r3, r4] ]
-- | Sign a message using an explicitly supplied padding.
--
-- The digest of @padding ++ m@ must be a quadratic residue modulo both
-- private primes, otherwise 'calculateHash' returns @Left InvalidParameters@.
-- It must also be smaller than the modulus, otherwise the result is
-- @Left MessageTooLong@. The signature pairs the padding, encoded as an
-- Integer with 'os2ip', with the first square root returned by 'sqroot''.
--
-- NOTE: storing the padding as an Integer discards leading zero bytes.
-- 'verify' rebuilds the padding with 'i2osp', so a padding whose first byte
-- is zero will produce a signature that does not verify. An empty padding
-- may not round-trip either. Callers should avoid such paddings.
signWith :: HashAlgorithm hash
         => ByteString
         -> PrivateKey
         -> hash
         -> ByteString
         -> Either Error Signature
signWith padding pk hashAlg m =
    calculateHash padding pk hashAlg m >>= calculateSignature
  where
    p = private_p pk
    q = private_q pk
    a = private_a pk
    b = private_b pk
    n = public_n $ private_pub pk

    calculateSignature h
        | h >= n    = Left MessageTooLong
        | otherwise =
            let (r, _, _, _) = sqroot' h p q a b n
            in Right $ Signature (os2ip padding, r)
-- | Sign a message, choosing a random 8-byte padding.
--
-- Paddings are redrawn until one is usable (see 'findPadding'), and the
-- result of 'signWith' for that padding is returned.
sign :: (MonadRandom m, HashAlgorithm hash)
     => PrivateKey
     -> hash
     -> ByteString
     -> m (Either Error Signature)
sign pk hashAlg m = do
    padding <- findPadding
    return $ signWith padding pk hashAlg m
  where
    -- A padding is usable when both of these hold:
    --  * its first byte is non-zero. The padding is stored in the Signature
    --    as an Integer (os2ip), and 'verify' recovers it with i2osp, which
    --    cannot restore leading zero bytes. Without this check roughly one
    --    signature in 256 would never verify.
    --  * hash(padding ++ m) is a quadratic residue modulo p and q, which is
    --    what 'calculateHash' checks.
    findPadding = do
        padding <- getRandomBytes 8
        case (B.uncons padding, calculateHash padding pk hashAlg m) of
            (Just (firstByte, _), Right _) | firstByte /= 0 -> return padding
            _                                               -> findPadding
-- | Hash @padding ++ m@ and interpret the digest as an Integer.
--
-- Succeeds only when the Jacobi symbol of the digest is 1 modulo both
-- private primes. Otherwise it returns @Left InvalidParameters@.
calculateHash :: HashAlgorithm hash
              => ByteString
              -> PrivateKey
              -> hash
              -> ByteString
              -> Either Error Integer
calculateHash padding pk hashAlg m
    | isResidue p && isResidue q = Right h
    | otherwise                  = Left InvalidParameters
  where
    p = private_p pk
    q = private_q pk
    h = os2ip $ hashWith hashAlg $ B.append padding m
    isResidue prime = jacobi (h `mod` prime) prime == Just 1
-- | Verify a Rabin signature.
--
-- The padding is rebuilt from its Integer form with 'i2osp' and
-- @padding ++ m@ is hashed. The signature is valid when @s^2 mod n@ equals
-- that digest.
--
-- NOTE: an Integer cannot carry leading zero bytes, so signatures produced
-- with a padding that starts with a zero byte do not verify.
verify :: HashAlgorithm hash
       => PublicKey
       -> hash
       -> ByteString
       -> Signature
       -> Bool
verify pk hashAlg m (Signature (padding, s)) =
    expSafe s 2 (public_n pk) == digest
  where
    digest = os2ip $ hashWith hashAlg $ B.append (i2osp padding) m
-- | Square roots of @a@ modulo a prime @p@ with @p ≡ 3 (mod 4)@, computed as
-- @a^((p+1)/4) mod p@.
--
-- The result is only a true root when @a@ is a quadratic residue mod @p@.
-- The second component is the plain negation and is not reduced modulo @p@.
sqroot :: Integer
       -> Integer
       -> (Integer, Integer)
sqroot a p = (root, negate root)
  where
    root = expSafe a ((p + 1) `div` 4) p
-- | The four square roots of @a@ modulo @n = p * q@, combined with the CRT.
--
-- Expects @c * p + d * q = 1@ (the coefficients from @gcde p q@). Then
-- @d * q ≡ 1 (mod p)@ and @c * p ≡ 1 (mod q)@. The roots are returned as
-- @(x, -x mod n, y, -y mod n)@.
sqroot' :: Integer
        -> Integer
        -> Integer
        -> Integer
        -> Integer
        -> Integer
        -> (Integer, Integer, Integer, Integer)
sqroot' a p q c d n = (x, negMod x, y, negMod y)
  where
    (rootP, _) = sqroot a p
    (rootQ, _) = sqroot a q
    fromP      = rootP * d * q  -- ≡ rootP (mod p), ≡ 0 (mod q)
    fromQ      = rootQ * c * p  -- ≡ 0 (mod p), ≡ rootQ (mod q)
    x          = (fromP + fromQ) `mod` n
    y          = (fromP - fromQ) `mod` n
    negMod z   = negate z `mod` n