-- SecretKey.hs: OpenPGP (RFC4880) secret key encryption/decryption
-- Copyright © 2013-2016  Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).

module Codec.Encryption.OpenPGP.SecretKey (
   decryptPrivateKey
 , encryptPrivateKey
 , encryptPrivateKeyIO
 , reencryptSecretKeyIO
) where

import Codec.Encryption.OpenPGP.Internal.HOBlockCipher
import Codec.Encryption.OpenPGP.Types
import Codec.Encryption.OpenPGP.BlockCipher (withSymmetricCipher, keySize)
import Codec.Encryption.OpenPGP.CFB (decryptNoNonce, encryptNoNonce)
import Codec.Encryption.OpenPGP.Serialize (getSecretKey)
import Codec.Encryption.OpenPGP.S2K (skesk2Key, string2Key)
import qualified Crypto.Hash as CH
import Crypto.Number.ModArithmetic (inverse)
import Crypto.Random.EntropyPool (createEntropyPool, getEntropyFrom)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Binary (put)
import Data.Binary.Get (getRemainingLazyByteString, getWord16be, runGetOrFail)
import Data.Binary.Put (runPut)
import Data.Bifunctor (bimap)
import qualified Crypto.PubKey.RSA as R

-- | Block size, in octets, of the given symmetric algorithm.
--
-- The size is read from a cipher set up through 'withSymmetricCipher' with an
-- empty key; if that setup yields a 'Left', 0 is returned instead.
-- NOTE(review): this relies on cipher setup tolerating an empty key (or on
-- 'blockSize' not needing a usable cipher) — confirm, since a silent 0 here
-- would leave 'encryptPrivateKeyIO' with an empty IV.
saBlockSize :: SymmetricAlgorithm -> Int
saBlockSize sa =
    case withSymmetricCipher sa B.empty (Right . blockSize) of
        Right octets -> octets
        Left _ -> 0

-- | Decrypt the secret-key addendum paired with its public payload, using
-- passphrase @pp@, and return it in 'SUUnencrypted' form.
--
-- An addendum that is already 'SUUnencrypted' is returned as-is. Calls
-- 'error' if decryption of a 'SUS16bit' or 'SUSSHA1' addendum fails, and for
-- 'SUSym' addenda, which are not supported yet.
decryptPrivateKey :: (PKPayload, SKAddendum) -> BL.ByteString -> SKAddendum
decryptPrivateKey (pkp, ska) pp =
    case ska of
        SUS16bit{} -> decryptedOr "could not decrypt SUS16bit"
        SUSSHA1{} -> decryptedOr "could not decrypt SUSSHA1"
        SUSym{} -> error "SUSym key decryption not implemented"
        SUUnencrypted{} -> ska
  where
    -- run the real decryption, turning any failure into the given error
    decryptedOr msg = either (error msg) id (decryptSKA (pkp, ska) pp)

-- | Decrypt a passphrase-protected secret-key addendum and return it in
-- 'SUUnencrypted' form, after verifying the integrity value stored after the
-- secret key material (RFC 4880 §5.5.3):
--
--   * 'SUS16bit': a big-endian 16-bit sum, mod 65536, of the preceding octets;
--   * 'SUSSHA1': a SHA-1 digest of the preceding octets.
--
-- Returns 'Left' if decryption or parsing fails, if the integrity value does
-- not match (most likely a wrong passphrase), or for any other addendum kind.
decryptSKA :: (PKPayload, SKAddendum) -> BL.ByteString -> Either String SKAddendum
decryptSKA (pkp, SUS16bit sa s2k iv payload) pp = do
    let key = skesk2Key (SKESK 4 sa s2k Nothing) pp
    p <- decryptNoNonce sa iv (BL.toStrict payload) key
    (s, cksum) <- getSecretKeyAndChecksum p
    -- the trailing two octets hold the 16-bit sum of everything before them
    let expected = sum . map fromIntegral . B.unpack . B.take (B.length p - 2) $ p
    if cksum == expected
        then return $ SUUnencrypted s cksum
        else Left "SUS16bit checksum mismatch (wrong passphrase?)"
    where
        third (_, _, x) = x
        getSecretKeyAndChecksum p = bimap third third (runGetOrFail ((,) <$> getSecretKey pkp <*> getWord16be) (BL.fromStrict p))
decryptSKA (pkp, SUSSHA1 sa s2k iv payload) pp = do
    let key = skesk2Key (SKESK 4 sa s2k Nothing) pp
    p <- decryptNoNonce sa iv (BL.toStrict payload) key
    (s, digest) <- getSecretKeyAndChecksum p
    -- the trailing 20 octets hold the SHA-1 digest of everything before them
    let algospecific = B.take (B.length p - 20) p
        expected = BA.convert (CH.hash algospecific :: CH.Digest CH.SHA1) :: B.ByteString
        -- unencrypted keys carry the 16-bit sum of the secret key material
        checksum = sum . map fromIntegral . B.unpack $ algospecific
    if BL.toStrict digest == expected
        then return $ SUUnencrypted s checksum
        else Left "SUSSHA1 hash mismatch (wrong passphrase?)"
    where
        third (_, _, x) = x
        getSecretKeyAndChecksum p = bimap third third (runGetOrFail ((,) <$> getSecretKey pkp <*> getRemainingLazyByteString) (BL.fromStrict p))
decryptSKA _ _ = Left "Unexpected codepath"

-- | Like 'encryptPrivateKey', but obtains the salt and IV itself: it draws
-- @8 + saBlockSize AES256@ octets from a freshly created entropy pool, uses
-- the first 8 as the S2K salt and the remainder as the IV.
encryptPrivateKeyIO :: SKAddendum -> BL.ByteString -> IO SKAddendum
encryptPrivateKeyIO ska pp = do
    pool <- createEntropyPool
    randomBytes <- getEntropyFrom pool (8 + saBlockSize AES256)
    let (salt, ivBytes) = B.splitAt 8 randomBytes
    return (encryptPrivateKey salt (IV ivBytes) ska pp)

-- | Encrypt an 'SUUnencrypted' addendum under passphrase @pp@ with AES-256,
-- producing a 'SUSSHA1' addendum. The key is derived with an iterated+salted
-- SHA-512 S2K (count 12058624) using @salt@, which should be 8 octets; the IV
-- must be one cipher block long.
--
-- Addenda that are already protected ('SUS16bit', 'SUSSHA1', 'SUSym') are
-- returned unchanged.
encryptPrivateKey :: B.ByteString -> IV -> SKAddendum -> BL.ByteString -> SKAddendum
encryptPrivateKey salt iv ska pp =
    case ska of
        SUUnencrypted skey _ -> SUSSHA1 AES256 s2k iv (BL.fromStrict (encryptSKey skey s2k iv pp))
        SUS16bit{} -> ska
        SUSSHA1{} -> ska
        SUSym{} -> ska
  where
    s2k = IteratedSalted SHA512 (Salt salt) 12058624

-- | Serialize the RSA secret MPIs d, p, q and u, append the SHA-1 digest of
-- that serialization, and encrypt the result with AES-256 (via the CFB
-- module's 'encryptNoNonce') under a key derived from passphrase @pp@ and
-- @s2k@.
--
-- Calls 'error' if encryption fails, if @p@ has no inverse modulo @q@, or for
-- any non-RSA key type.
encryptSKey :: SKey -> S2K -> IV -> BL.ByteString -> B.ByteString
encryptSKey (RSAPrivateKey (RSA_PrivateKey (R.PrivateKey _ d p q _ _ _))) s2k iv pp = either error id (encryptNoNonce AES256 s2k iv (BL.toStrict payload) key)
    where
        key = string2Key s2k (keySize AES256) pp
        -- RFC 4880 §5.5.3: u is the multiplicative inverse of p, mod q
        u = maybe (error "encryptSKey: RSA prime p has no inverse modulo q") id (inverse p q)
        algospecific = runPut $ put (MPI d) >> put (MPI p) >> put (MPI q) >> put (MPI u)
        -- SHA-1 integrity digest over the plaintext MPIs
        cksum = CH.hashlazy algospecific :: CH.Digest CH.SHA1
        payload = algospecific `BL.append` BL.fromStrict (BA.convert cksum)
encryptSKey _ _ _ _ = error "Non-RSA keytypes not handled yet" -- FIXME: do DSA and ElGamal

-- | Run the key's secret-key addendum through 'encryptPrivateKeyIO' with
-- passphrase @pp@ and store the result back; all other fields of the
-- 'SecretKey' are left untouched. An addendum that is already protected comes
-- back unchanged from 'encryptPrivateKey'.
reencryptSecretKeyIO :: SecretKey -> BL.ByteString -> IO SecretKey
reencryptSecretKeyIO sk pp = do
    protected <- encryptPrivateKeyIO (_secretKeySKAddendum sk) pp
    return sk { _secretKeySKAddendum = protected }