{-# LANGUAGE CPP #-}
module Crypto.PubKey.ECIES.Conduit
  ( encrypt
  , decrypt
  ) where

import           Control.Monad.Catch                  (MonadThrow, throwM)
import           Control.Monad.Trans.Class            (lift)
import qualified Crypto.Cipher.ChaCha                 as ChaCha
import qualified Crypto.Cipher.ChaChaPoly1305.Conduit as ChaCha
import qualified Crypto.ECC                           as ECC
import qualified Crypto.Error                         as CE
import           Crypto.Hash                          (SHA512 (..), hashWith)
import           Crypto.PubKey.ECIES                  (deriveDecrypt,
                                                       deriveEncrypt)
import           Crypto.Random                        (MonadRandom,
                                                       drgNewTest, withDRG)
import qualified Data.ByteArray                       as BA
import           Data.ByteString                      (ByteString)
import qualified Data.ByteString                      as B
import qualified Data.ByteString.Lazy                 as BL
import           Data.Conduit                         (ConduitM, yield)
import qualified Data.Conduit.Binary                  as CB
import           Data.Proxy                           (Proxy (..))
import           System.IO.Unsafe                     (unsafePerformIO)

-- | Derive the symmetric nonce and key from an ECDH shared secret.
--
-- The shared secret is hashed with SHA-512. The first 40 bytes of the digest
-- seed a ChaCha keystream generator ('ChaCha.initializeSimple'). A 12-byte
-- nonce and then a 32-byte key are drawn from it, in that order.
--
-- The derivation is a pure function of the shared secret. As a result,
-- 'encrypt' and 'decrypt' obtain the same key from the same secret.
getNonceKey :: ECC.SharedSecret -> (ByteString, ByteString)
getNonceKey shared =
  let seed            = B.take seedSize $ BA.convert $ hashWith SHA512 shared
      state1          = ChaCha.initializeSimple seed
      (nonce, state2) = ChaCha.generateSimple state1 nonceSize
      (key, _)        = ChaCha.generateSimple state2 keySize
   in (nonce, key)
  where
    seedSize, nonceSize, keySize :: Int
    -- Seed length handed to 'ChaCha.initializeSimple'.
    -- NOTE(review): assumed to be the generator's required seed size; confirm.
    seedSize  = 40
    -- Nonce and key sizes match IETF ChaCha20-Poly1305 (RFC 8439).
    nonceSize = 12
    keySize   = 32

-- | The elliptic curve used by every operation in this module
-- (secp256r1, a.k.a. P-256).
type Curve = ECC.Curve_P256R1

-- | Type-level witness for 'Curve'. It is passed to the curve-polymorphic
-- "Crypto.ECC" and "Crypto.PubKey.ECIES" functions to select the curve.
proxy :: Proxy Curve
proxy = Proxy

-- | Length in bytes of an encoded 'Curve' point. This is the size of the header
-- that 'encrypt' writes before the ciphertext and that 'decrypt' reads back.
--
-- 'ECC.encodePoint' needs a concrete point, so a throwaway key pair is
-- generated from a fixed-seed deterministic generator. This keeps the
-- definition pure, with no 'unsafePerformIO' and no system entropy.
-- The key pair is used only for its encoded length, so the predictable seed
-- is harmless.
--
-- NOTE(review): like the original, this assumes every 'Curve' point encodes
-- to the same number of bytes.
pointBinarySize :: Int
pointBinarySize = B.length (ECC.encodePoint proxy samplePoint)
  where
    samplePoint :: ECC.Point Curve
    samplePoint =
      ECC.keypairGetPublic . fst $
        withDRG (drgNewTest (0, 0, 0, 0, 0)) (ECC.curveGenerateKeyPair proxy)

-- | Turn a 'CE.CryptoFailable' into a result in any 'MonadThrow'.
--
-- 'CE.CryptoPassed' values are returned as-is. A 'CE.CryptoFailed' error is
-- rethrown with 'throwM'.
throwOnFail :: MonadThrow m => CE.CryptoFailable a -> m a
throwOnFail (CE.CryptoPassed a) = pure a
throwOnFail (CE.CryptoFailed e) = throwM e


-- | Encrypt a stream for the holder of the given public point.
--
-- Steps:
--
-- 1. Call 'deriveEncrypt' (in the 'MonadRandom' base monad) to obtain a
--    random ephemeral point and the ECDH shared secret with @point@.
-- 2. Emit the encoded ephemeral point as the first output chunk.
-- 3. Encrypt the input with 'ChaCha.encrypt', using the nonce and key from
--    'getNonceKey'.
--
-- Throws the 'CE.CryptoError' (via 'MonadThrow') if the shared secret cannot
-- be derived.
encrypt
  :: (MonadThrow m, MonadRandom m)
  => ECC.Point Curve
  -> ConduitM ByteString ByteString m ()
encrypt point = do
  (ephemeral, shared) <- lift (deriveEncrypt proxy point) >>= throwOnFail
  let (nonce, key) = getNonceKey shared
  -- The recipient needs the ephemeral point to recompute the shared secret,
  -- so it is sent in the clear ahead of the ciphertext.
  yield $ ECC.encodePoint proxy ephemeral
  ChaCha.encrypt nonce key

-- | Decrypt a stream produced by 'encrypt', using the recipient's private
-- scalar.
--
-- Steps:
--
-- 1. Read the first 'pointBinarySize' bytes as the sender's encoded
--    ephemeral point and decode it.
-- 2. Recover the ECDH shared secret with 'deriveDecrypt'.
-- 3. Derive the key via 'getNonceKey' and pass the rest of the stream to
--    'ChaCha.decrypt'.
--
-- Only the key is taken from 'getNonceKey'. 'ChaCha.decrypt' receives no
-- nonce argument; presumably it reads the nonce that 'ChaCha.encrypt' wrote
-- into the stream (verify).
--
-- Throws a 'CE.CryptoError' (via 'MonadThrow') in two cases:
--
-- * the header bytes do not decode as a point (presumably including a stream
--   shorter than 'pointBinarySize');
-- * the shared secret cannot be derived.
decrypt
  :: (MonadThrow m)
  => ECC.Scalar Curve
  -> ConduitM ByteString ByteString m ()
decrypt scalar = do
  pointBS <- BL.toStrict <$> CB.take pointBinarySize
  point   <- throwOnFail (ECC.decodePoint proxy pointBS)
  shared  <- throwOnFail (deriveDecrypt proxy point scalar)
  let (_nonce, key) = getNonceKey shared
  ChaCha.decrypt key