{-# LANGUAGE DeriveDataTypeable #-} module Crypto.Cipher.ChaChaPoly1305.Conduit ( encrypt , decrypt , ChaChaException (..) ) where import Control.Exception (assert) import Control.Monad.Catch (Exception, MonadThrow, throwM) import qualified Crypto.Cipher.ChaChaPoly1305 as Cha import qualified Crypto.Error as CE import qualified Crypto.MAC.Poly1305 as Poly1305 import qualified Data.ByteArray as BA import Data.ByteString (ByteString) import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as BL import Data.Conduit (ConduitM, await, leftover, yield) import qualified Data.Conduit.Binary as CB import Data.Typeable (Typeable) cf :: MonadThrow m => (CE.CryptoError -> ChaChaException) -> CE.CryptoFailable a -> m a cf _ (CE.CryptoPassed x) = return x cf f (CE.CryptoFailed e) = throwM (f e) data ChaChaException = EncryptNonceException !CE.CryptoError | EncryptKeyException !CE.CryptoError | DecryptNonceException !CE.CryptoError | DecryptKeyException !CE.CryptoError | MismatchedAuth deriving (Show, Typeable) instance Exception ChaChaException encrypt :: MonadThrow m => ByteString -- ^ nonce (12 random bytes) -> ByteString -- ^ symmetric key (32 bytes) -> ConduitM ByteString ByteString m () encrypt nonceBS key = do nonce <- cf EncryptNonceException $ Cha.nonce12 nonceBS state0 <- cf EncryptKeyException $ Cha.initialize key nonce yield nonceBS let loop state1 = do mbs <- await case mbs of Nothing -> yield $ BA.convert $ Cha.finalize state1 Just bs -> do let (bs', state2) = Cha.encrypt bs state1 yield bs' loop state2 loop $ Cha.finalizeAAD state0 decrypt :: MonadThrow m => ByteString -- ^ symmetric key (32 bytes) -> ConduitM ByteString ByteString m () decrypt key = do nonceBS <- CB.take 12 nonce <- cf DecryptNonceException $ Cha.nonce12 $ BL.toStrict nonceBS state0 <- cf DecryptKeyException $ Cha.initialize key nonce let loop state1 = do ebs <- awaitExcept16 id case ebs of Left final -> case Poly1305.authTag final of CE.CryptoPassed final' | Cha.finalize state1 == final' -> return () _ -> throwM MismatchedAuth Right bs -> do let (bs', state2) = Cha.decrypt bs state1 yield bs' loop state2 loop $ Cha.finalizeAAD state0 where awaitExcept16 front = do mbs <- await case mbs of Nothing -> return $ Left $ front B.empty Just bs -> do let bs' = front bs if B.length bs' > 16 then do let (x, y) = B.splitAt (B.length bs' - 16) bs' assert (B.length y == 16) leftover y return $ Right x else awaitExcept16 (B.append bs')