{-# LANGUAGE CPP #-}
-- | Streaming ECIES encryption and decryption on top of ChaCha20-Poly1305.
--
-- The ciphertext stream produced by 'encrypt' is the encoded ephemeral
-- public point, followed by whatever 'ChaCha.encrypt' emits for the
-- plaintext. 'decrypt' consumes the same layout.
module Crypto.PubKey.ECIES.Conduit
  ( encrypt
  , decrypt
  ) where

import           Control.Monad.Catch                  (MonadThrow, throwM)
import           Control.Monad.Trans.Class            (lift)
import qualified Crypto.Cipher.ChaCha                 as ChaCha
import qualified Crypto.Cipher.ChaChaPoly1305.Conduit as ChaCha
import qualified Crypto.ECC                           as ECC
import qualified Crypto.Error                         as CE
import           Crypto.Hash                          (SHA512 (..), hashWith)
import           Crypto.PubKey.ECIES                  (deriveDecrypt, deriveEncrypt)
import           Crypto.Random                        (MonadRandom)
import qualified Data.ByteArray                       as BA
import           Data.ByteString                      (ByteString)
import qualified Data.ByteString                      as B
import qualified Data.ByteString.Lazy                 as BL
import           Data.Conduit                         (ConduitM, yield)
import qualified Data.Conduit.Binary                  as CB
import           Data.Proxy                           (Proxy (..))
import           System.IO.Unsafe                     (unsafePerformIO)

-- | Derive a @(nonce, key)@ pair from an ECDH shared secret.
--
-- The secret is hashed with SHA-512 and the first 40 bytes of the digest
-- seed a ChaCha keystream ('ChaCha.initializeSimple'). The first 12
-- keystream bytes become the nonce and the next 32 bytes the key.
getNonceKey :: ECC.SharedSecret -> (ByteString, ByteString)
getNonceKey shared =
  let state1 =
        ChaCha.initializeSimple $ B.take 40 $ BA.convert $ hashWith SHA512 shared
      (nonce, state2) = ChaCha.generateSimple state1 12
      (key, _) = ChaCha.generateSimple state2 32
   in (nonce, key)

-- | The elliptic curve used for all key agreement in this module.
type Curve = ECC.Curve_P256R1

-- | Type witness for 'Curve', passed to the "Crypto.ECC" functions.
proxy :: Proxy Curve
proxy = Proxy

-- | Length in bytes of an encoded 'Curve' point. This is the size of the
-- header 'encrypt' writes and 'decrypt' reads.
--
-- The length is measured by encoding the public point of a freshly
-- generated key pair. 'unsafePerformIO' is only used to obtain randomness
-- for that throw-away key pair. The result is assumed not to depend on
-- which point is sampled, because 'ECC.encodePoint' is expected to produce
-- a fixed-length encoding for every point of 'Curve' (TODO confirm for
-- the cryptonite version in use). @NOINLINE@ keeps this a single
-- top-level CAF, so the key pair is generated at most once.
pointBinarySize :: Int
pointBinarySize = B.length (ECC.encodePoint proxy point :: ByteString)
  where
    point :: ECC.Point Curve
    point =
      unsafePerformIO (ECC.keypairGetPublic <$> ECC.curveGenerateKeyPair proxy)
{-# NOINLINE pointBinarySize #-}

-- | Unwrap a 'CE.CryptoFailable', rethrowing a 'CE.CryptoFailed' error via
-- 'throwM'.
throwOnFail :: MonadThrow m => CE.CryptoFailable a -> m a
throwOnFail (CE.CryptoPassed a) = pure a
throwOnFail (CE.CryptoFailed e) = throwM e

-- | Encrypt a byte stream to the holder of the given public point.
--
-- A fresh ephemeral point and shared secret are derived with
-- 'deriveEncrypt' (using the 'MonadRandom' instance of @m@). A derivation
-- failure is thrown as a 'CE.CryptoError'. The encoded ephemeral point is
-- yielded first. The rest of the stream is delegated to 'ChaCha.encrypt',
-- keyed by the nonce and key from 'getNonceKey'.
encrypt
  :: (MonadThrow m, MonadRandom m)
  => ECC.Point Curve
  -> ConduitM ByteString ByteString m ()
encrypt point = do
  (point', shared) <- lift (deriveEncrypt proxy point) >>= throwOnFail
  let (nonce, key) = getNonceKey shared
  yield (ECC.encodePoint proxy point')
  ChaCha.encrypt nonce key

-- | Decrypt a stream produced by 'encrypt', using the recipient's secret
-- scalar.
--
-- The first 'pointBinarySize' bytes are decoded as the ephemeral point. A
-- malformed or truncated header makes 'ECC.decodePoint' fail, and that
-- failure is thrown. The shared secret is then derived with
-- 'deriveDecrypt', and the remaining input is handed to 'ChaCha.decrypt'
-- with the derived key.
decrypt
  :: MonadThrow m
  => ECC.Scalar Curve
  -> ConduitM ByteString ByteString m ()
decrypt scalar = do
  pointBS <- BL.toStrict <$> CB.take pointBinarySize
  point <- throwOnFail (ECC.decodePoint proxy pointBS)
  shared <- throwOnFail (deriveDecrypt proxy point scalar)
  -- The derived nonce is not passed on here. NOTE(review): presumably
  -- 'ChaCha.decrypt' recovers the nonce from the stream itself — confirm
  -- against Crypto.Cipher.ChaChaPoly1305.Conduit.
  let (_nonce, key) = getNonceKey shared
  ChaCha.decrypt key