{-# LANGUAGE CPP #-}
-- | Streaming ECIES over 'ECC.Curve_P256R1'.
--
-- The ECDH shared secret obtained from "Crypto.PubKey.ECIES" is stretched
-- into a key (and nonce) for the ChaChaPoly1305 conduits, which then
-- process the payload.
module Crypto.PubKey.ECIES.Conduit
  ( encrypt
  , decrypt
  ) where

import Control.Monad.Catch (MonadThrow, throwM)
import Control.Monad.Trans.Class (lift)
import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Crypto.Cipher.ChaChaPoly1305.Conduit as ChaCha
import qualified Crypto.ECC as ECC
import qualified Crypto.Error as CE
import Crypto.Hash (SHA512 (..), hashWith)
import Crypto.PubKey.ECIES (deriveDecrypt, deriveEncrypt)
import Crypto.Random (MonadRandom)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Conduit (ConduitM, yield)
import qualified Data.Conduit.Binary as CB
import Data.Proxy (Proxy (..))
import System.IO.Unsafe (unsafePerformIO)

-- | Stretch a shared secret into a @(nonce, key)@ pair.
--
-- The first 40 bytes of SHA-512 of the secret seed a ChaCha keystream.
-- The first 12 keystream bytes become the nonce; the next 32 become the key.
getNonceKey :: ECC.SharedSecret -> (ByteString, ByteString)
getNonceKey shared = (nonce, key)
  where
    seed = B.take 40 (BA.convert (hashWith SHA512 shared))
    (nonce, afterNonce) = ChaCha.generateSimple (ChaCha.initializeSimple seed) 12
    (key, _unused) = ChaCha.generateSimple afterNonce 32

-- | The curve used for every ECIES operation in this module.
type Curve = ECC.Curve_P256R1

-- | Type-level witness for 'Curve', passed to the polymorphic ECC functions.
proxy :: Proxy Curve
proxy = Proxy

-- | Length in bytes of an encoded 'Curve' point.
--
-- Measured once by encoding the public half of a throwaway key pair.
-- NOTE(review): this assumes every point of 'Curve' encodes to the same
-- length, which 'decrypt' relies on to split the header from the body.
-- NOINLINE keeps GHC from duplicating the 'unsafePerformIO' action.
pointBinarySize :: Int
pointBinarySize = unsafePerformIO $ do
  keyPair <- ECC.curveGenerateKeyPair proxy
  pure $ B.length (ECC.encodePoint proxy (ECC.keypairGetPublic keyPair))
{-# NOINLINE pointBinarySize #-}

-- | Turn a cryptonite 'CE.CryptoFailable' into a 'MonadThrow' failure.
--
-- 'CE.CryptoPassed' yields its value; 'CE.CryptoFailed' rethrows its
-- 'CE.CryptoError' via 'throwM'.
throwOnFail :: MonadThrow m => CE.CryptoFailable a -> m a
throwOnFail result =
  case result of
    CE.CryptoPassed value -> pure value
    CE.CryptoFailed err   -> throwM err

-- | Encrypt a byte stream for the holder of the given public point.
--
-- Output layout: the encoded point returned by 'deriveEncrypt', followed by
-- whatever 'ChaCha.encrypt' emits for the remaining input.
-- Fails through 'MonadThrow' if the shared-secret derivation fails.
encrypt :: (MonadThrow m, MonadRandom m)
        => ECC.Point Curve
        -> ConduitM ByteString ByteString m ()
encrypt point = do
  derived <- lift (deriveEncryptCompat proxy point)
  (outPoint, shared) <- throwOnFail derived
  let (nonce, key) = getNonceKey shared
  -- The point is written first so 'decrypt' can read it back as a
  -- 'pointBinarySize'-byte header.
  yield (ECC.encodePoint proxy outPoint)
  ChaCha.encrypt nonce key
  where
    -- Since cryptonite 0.24 'deriveEncrypt' already returns a
    -- 'CE.CryptoFailable'; older versions are wrapped to the same shape.
#if MIN_VERSION_cryptonite(0,23,999)
    deriveEncryptCompat prx p = deriveEncrypt prx p
#else
    deriveEncryptCompat prx p = CE.CryptoPassed <$> deriveEncrypt prx p
#endif

-- | Decrypt a stream produced by 'encrypt' using the recipient's scalar.
--
-- Reads up to 'pointBinarySize' bytes as the encoded point and derives the
-- shared secret from it. The rest of the stream is handed to
-- 'ChaCha.decrypt' with the derived key.
-- Point-decoding or derivation failures are thrown via 'MonadThrow'.
-- A truncated header presumably surfaces as a decoding failure.
decrypt :: (MonadThrow m)
        => ECC.Scalar Curve
        -> ConduitM ByteString ByteString m ()
decrypt scalar = do
  encodedPoint <- BL.toStrict <$> CB.take pointBinarySize
  inPoint <- throwOnFail (ECC.decodePoint proxy encodedPoint)
  shared <- throwOnFail (deriveDecryptCompat proxy inPoint scalar)
  -- Only the key is needed here: 'ChaCha.decrypt' takes no nonce argument
  -- (presumably it reads the nonce that 'ChaCha.encrypt' put in the stream).
  ChaCha.decrypt (snd (getNonceKey shared))
  where
    -- Since cryptonite 0.24 'deriveDecrypt' already returns a
    -- 'CE.CryptoFailable'; older versions are wrapped to the same shape.
#if MIN_VERSION_cryptonite(0,23,999)
    deriveDecryptCompat prx p s = deriveDecrypt prx p s
#else
    deriveDecryptCompat prx p s = CE.CryptoPassed (deriveDecrypt prx p s)
#endif