module Crypto.PubKey.RSA.PSS
    ( PSSParams(..)
    , defaultPSSParams
    , defaultPSSParamsSHA1
    
    , signWithSalt
    , signDigestWithSalt
    , sign
    , signDigest
    , signSafer
    , signDigestSafer
    , verify
    , verifyDigest
    ) where
import           Crypto.Random.Types
import           Crypto.PubKey.RSA.Types
import           Crypto.PubKey.RSA.Prim
import           Crypto.PubKey.RSA (generateBlinder)
import           Crypto.PubKey.MaskGenFunction
import           Crypto.Hash
import           Data.Bits (xor, shiftR, (.&.))
import           Data.Word
import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray)
import qualified Crypto.Internal.ByteArray as B (convert)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
-- | Parameters for RSASSA-PSS signing and verification.
data PSSParams hash seed output = PSSParams
    { pssHash         :: hash             -- ^ Hash algorithm applied to the message and to M'
    , pssMaskGenAlg   :: MaskGenAlgorithm seed output -- ^ Mask generation function, applied to the hash H to mask the data block
    , pssSaltLength   :: Int              -- ^ Length of the random salt drawn by 'sign' / 'signDigest' (ignored by the *WithSalt functions)
    , pssTrailerField :: Word8            -- ^ Final octet of the encoded message, checked on verification
    }
-- | Default PSS parameters for a given hash algorithm:
-- MGF1 built on that same hash, a salt as long as the hash digest,
-- and the trailer field @0xbc@.
defaultPSSParams :: (ByteArrayAccess seed, ByteArray output, HashAlgorithm hash)
                 => hash                       -- ^ Hash algorithm to use everywhere
                 -> PSSParams hash seed output
defaultPSSParams hashAlg = PSSParams
    { pssHash         = hashAlg
    , pssMaskGenAlg   = maskGen
    , pssSaltLength   = saltLen
    , pssTrailerField = trailer
    }
  where
    maskGen = mgf1 hashAlg
    saltLen = hashDigestSize hashAlg
    trailer = 0xbc
-- | 'defaultPSSParams' specialised to SHA1, working on 'ByteString's.
defaultPSSParamsSHA1 :: PSSParams SHA1 ByteString ByteString
defaultPSSParamsSHA1 = defaultPSSParams sha1
  where sha1 = SHA1
-- | Sign a precomputed message digest with RSASSA-PSS, using the salt given
-- explicitly instead of a random one.
--
-- 'pssSaltLength' from the parameters is ignored; the length of @salt@ is
-- used instead.
--
-- Returns @'Left' 'InvalidParameters'@ when the key size in octets is smaller
-- than digest length + salt length + 2 (the 0x01 separator and the trailer
-- field), i.e. when the encoded message cannot fit.
signDigestWithSalt :: HashAlgorithm hash
                   => ByteString    -- ^ Salt to use
                   -> Maybe Blinder -- ^ Optional blinder, passed to the private-key primitive 'dp'
                   -> PSSParams hash ByteString ByteString -- ^ PSS parameters
                   -> PrivateKey    -- ^ RSA private key
                   -> Digest hash   -- ^ Digest of the message to sign
                   -> Either Error ByteString
signDigestWithSalt salt blinder params pk digest
    | k < hashLen + saltLen + 2 = Left InvalidParameters
    | otherwise                 = Right $ dp blinder pk em
    where k        = private_size pk
          mHash    = B.convert digest
          -- DB and its mask cover everything except H and the trailer octet
          dbLen    = k - hashLen - 1
          saltLen  = B.length salt
          hashLen  = hashDigestSize (pssHash params)
          -- NOTE(review): treats the modulus bit length as private_size * 8,
          -- i.e. assumes a byte-aligned modulus; other sizes are not encoded
          -- the way RFC 8017 specifies (emBits = modBits - 1).
          pubBits  = k * 8
          -- M' = 8 zero octets || mHash || salt
          m'       = B.concat [B.replicate 8 0, mHash, salt]
          h        = B.convert $ hashWith (pssHash params) m'
          -- DB = zero padding || 0x01 || salt, exactly dbLen octets
          db       = B.concat [B.replicate (dbLen - saltLen - 1) 0, B.singleton 1, salt]
          dbmask   = pssMaskGenAlg params h dbLen
          -- clear the high bits of the leading octet (see normalizeToKeySize)
          maskedDB = B.pack $ normalizeToKeySize pubBits $ B.zipWith xor db dbmask
          -- EM = maskedDB || H || trailer field
          em       = B.concat [maskedDB, h, B.singleton (pssTrailerField params)]
-- | Hash a message with the parameters' hash algorithm, then sign the digest
-- with the given salt via 'signDigestWithSalt'.
signWithSalt :: HashAlgorithm hash
             => ByteString    -- ^ Salt to use
             -> Maybe Blinder -- ^ Optional blinder, passed through to 'signDigestWithSalt'
             -> PSSParams hash ByteString ByteString -- ^ PSS parameters
             -> PrivateKey    -- ^ RSA private key
             -> ByteString    -- ^ Message to sign
             -> Either Error ByteString
signWithSalt salt blinder params pk =
    signDigestWithSalt salt blinder params pk . hashWith (pssHash params)
-- | Sign a message using a random salt of 'pssSaltLength' octets obtained
-- from 'getRandomBytes'.
sign :: (HashAlgorithm hash, MonadRandom m)
     => Maybe Blinder   -- ^ Optional blinder, passed through to 'signWithSalt'
     -> PSSParams hash ByteString ByteString -- ^ PSS parameters
     -> PrivateKey      -- ^ RSA private key
     -> ByteString      -- ^ Message to sign
     -> m (Either Error ByteString)
sign blinder params pk m =
    getRandomBytes (pssSaltLength params) >>= \freshSalt ->
        return (signWithSalt freshSalt blinder params pk m)
-- | Sign a precomputed digest using a random salt of 'pssSaltLength' octets
-- obtained from 'getRandomBytes'.
signDigest :: (HashAlgorithm hash, MonadRandom m)
           => Maybe Blinder   -- ^ Optional blinder, passed through to 'signDigestWithSalt'
           -> PSSParams hash ByteString ByteString -- ^ PSS parameters
           -> PrivateKey      -- ^ RSA private key
           -> Digest hash     -- ^ Digest of the message to sign
           -> m (Either Error ByteString)
signDigest blinder params pk digest =
    (return . withSalt) =<< getRandomBytes (pssSaltLength params)
  where
    withSalt salt = signDigestWithSalt salt blinder params pk digest
-- | Like 'sign', but always uses a blinder freshly generated from the
-- private key's modulus with 'generateBlinder'.
signSafer :: (HashAlgorithm hash, MonadRandom m)
          => PSSParams hash ByteString ByteString -- ^ PSS parameters
          -> PrivateKey     -- ^ RSA private key
          -> ByteString     -- ^ Message to sign
          -> m (Either Error ByteString)
signSafer params pk m =
    generateBlinder (private_n pk) >>= \freshBlinder ->
        sign (Just freshBlinder) params pk m
-- | Like 'signDigest', but always uses a blinder freshly generated from the
-- private key's modulus with 'generateBlinder'.
signDigestSafer :: (HashAlgorithm hash, MonadRandom m)
                => PSSParams hash ByteString ByteString -- ^ PSS parameters
                -> PrivateKey     -- ^ RSA private key
                -> Digest hash    -- ^ Digest of the message to sign
                -> m (Either Error ByteString)
signDigestSafer params pk digest =
    generateBlinder modulus >>= signWith
  where
    modulus    = private_n pk
    signWith b = signDigest (Just b) params pk digest
-- | Hash a message with the parameters' hash algorithm, then check the
-- signature against that digest with 'verifyDigest'.
verify :: HashAlgorithm hash
       => PSSParams hash ByteString ByteString
                     -- ^ PSS parameters the signature was produced with
       -> PublicKey  -- ^ RSA public key
       -> ByteString -- ^ Message that was signed
       -> ByteString -- ^ Signature to check
       -> Bool
verify params pk m = verifyDigest params pk (hashWith (pssHash params) m)
-- | Verify an RSASSA-PSS signature against a precomputed message digest.
--
-- The salt length is not taken from 'pssSaltLength'. It is recovered from the
-- position of the first 0x01 octet in the unmasked data block.
--
-- Returns 'False' when any of these checks fails:
--
-- * the signature length is not the key size;
-- * the trailer field does not match;
-- * the high bits of the masked block are not zero;
-- * the padding is not all zeros;
-- * the 0x01 separator is missing;
-- * the recomputed hash differs.
verifyDigest :: HashAlgorithm hash
             => PSSParams hash ByteString ByteString
                            -- ^ PSS parameters the signature was produced with
             -> PublicKey   -- ^ RSA public key
             -> Digest hash -- ^ Digest of the signed message
             -> ByteString  -- ^ Signature to check
             -> Bool
verifyDigest params pk digest s
    | public_size pk /= B.length s        = False
    | B.null em                           = False
    | B.last em /= pssTrailerField params = False
    | leadingBitsSet                      = False
    | B.any (/= 0) ps0                    = False
    | b1 /= B.singleton 1                 = False
    | otherwise                           = h == B.convert h'
        where -- parameters
              hashLen   = hashDigestSize (pssHash params)
              mHash     = B.convert digest
              dbLen     = public_size pk - hashLen - 1
              -- NOTE(review): treats the modulus bit length as public_size * 8,
              -- i.e. assumes a byte-aligned modulus.
              pubBits   = public_size pk * 8
              -- unmarshall fields: em = maskedDB || h || trailer
              em        = ep pk s
              maskedDB  = B.take (B.length em - hashLen - 1) em
              h         = B.take hashLen $ B.drop (B.length maskedDB) em
              -- RFC 8017 9.1.2 step 6: the bits that normalizeToKeySize
              -- would clear must already be zero in the received maskedDB
              leading   = B.unpack (B.take 1 maskedDB)
              leadingBitsSet = normalizeToKeySize pubBits leading /= leading
              dbmask    = pssMaskGenAlg params h dbLen
              db        = B.pack $ normalizeToKeySize pubBits $ B.zipWith xor maskedDB dbmask
              -- db = zero padding || 0x01 || salt
              (ps0,z)   = B.break (== 1) db
              (b1,salt) = B.splitAt 1 z
              m'        = B.concat [B.replicate 8 0, mHash, salt]
              h'        = hashWith (pssHash params) m'
-- | Clear the high bits of the leading octet so the value fits in
-- @bits - 1@ bits, the encoded-message length used by PSS.
--
-- Only the first element is modified. An empty list is returned unchanged.
-- Callers in this module pass the key size in octets times 8 as @bits@.
normalizeToKeySize :: Int -> [Word8] -> [Word8]
normalizeToKeySize _    []     = []
normalizeToKeySize bits (x:xs) = x .&. mask : xs
    where -- number of low bits of the leading octet that are kept,
          -- i.e. (bits - 1) mod 8; 0 means the whole octet is kept
          sh   = (bits - 1) .&. 0x7
          -- a shift by 8 would zero a Word8, hence the special case
          mask = if sh > 0 then 0xff `shiftR` (8 - sh) else 0xff