module Crypto.Cipher.ChaChaPoly1305.Conduit
( encrypt
, decrypt
, ChaChaException (..)
) where
import Control.Exception (assert)
import Control.Monad.Catch (Exception, MonadThrow, throwM)
import qualified Crypto.Cipher.ChaChaPoly1305 as Cha
import qualified Crypto.Error as CE
import qualified Crypto.MAC.Poly1305 as Poly1305
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Conduit (ConduitM, await, leftover, yield)
import qualified Data.Conduit.Binary as CB
import Data.Typeable (Typeable)
-- | Lift a 'CE.CryptoFailable' result into any 'MonadThrow'.
--
-- A successful result is returned unchanged. A failure is wrapped with the
-- supplied constructor and thrown as a 'ChaChaException'.
cf
  :: MonadThrow m
  => (CE.CryptoError -> ChaChaException) -- ^ wraps the underlying crypto error
  -> CE.CryptoFailable a                 -- ^ result to inspect
  -> m a
cf wrap result =
  case result of
    CE.CryptoPassed value -> pure value
    CE.CryptoFailed err   -> throwM (wrap err)
-- | Failures that 'encrypt' and 'decrypt' report through 'throwM'.
data ChaChaException
  = EncryptNonceException !CE.CryptoError
    -- ^ 'encrypt': the supplied nonce was rejected by @nonce12@.
  | EncryptKeyException !CE.CryptoError
    -- ^ 'encrypt': cipher initialization with the supplied key failed.
  | DecryptNonceException !CE.CryptoError
    -- ^ 'decrypt': the nonce read from the stream was rejected by @nonce12@.
  | DecryptKeyException !CE.CryptoError
    -- ^ 'decrypt': cipher initialization with the supplied key failed.
  | MismatchedAuth
    -- ^ 'decrypt': the trailing bytes were not a valid tag, or the tag did
    -- not match the one computed over the stream.
  deriving (Show, Typeable)

instance Exception ChaChaException
-- | Streaming ChaCha20-Poly1305 encryption.
--
-- The output stream is laid out as: the nonce exactly as supplied, then one
-- ciphertext chunk per input chunk, then the authentication tag once the
-- upstream is exhausted. No associated data is fed to the cipher.
--
-- Throws 'EncryptNonceException' if the nonce is rejected and
-- 'EncryptKeyException' if the cipher cannot be initialized with the key.
encrypt
  :: MonadThrow m
  => ByteString -- ^ nonce (passed to @nonce12@)
  -> ByteString -- ^ symmetric key
  -> ConduitM ByteString ByteString m ()
encrypt nonceBS key = do
  nonce   <- cf EncryptNonceException (Cha.nonce12 nonceBS)
  initial <- cf EncryptKeyException (Cha.initialize key nonce)
  -- The nonce goes out first so the decrypting side can read it back.
  yield nonceBS
  go (Cha.finalizeAAD initial)
  where
    -- Pull the next chunk, or finish when upstream is done.
    go st = await >>= maybe (finish st) (step st)

    -- End of input: emit the authentication tag as the last chunk.
    finish st = yield (BA.convert (Cha.finalize st))

    -- Encrypt one chunk, emit it, and continue with the updated state.
    step st chunk = do
      let (cipher, st') = Cha.encrypt chunk st
      yield cipher
      go st'
-- | Streaming ChaCha20-Poly1305 decryption of a stream produced by 'encrypt'.
--
-- Expected input layout: a 12-byte nonce, then ciphertext, then a 16-byte
-- authentication tag as the final bytes of the stream. No associated data is
-- fed to the cipher.
--
-- __Warning:__ plaintext chunks are yielded downstream /before/ the tag has
-- been verified; the tag can only be checked once the input ends. Consumers
-- must not act on the plaintext until the conduit has completed without
-- throwing.
--
-- Throws 'DecryptNonceException' if the leading bytes are rejected as a
-- nonce, 'DecryptKeyException' if the cipher cannot be initialized with the
-- key, and 'MismatchedAuth' if the trailing bytes are not a valid tag or do
-- not match the computed tag.
decrypt
  :: MonadThrow m
  => ByteString -- ^ symmetric key
  -> ConduitM ByteString ByteString m ()
decrypt key = do
  nonceBS <- CB.take 12
  nonce <- cf DecryptNonceException $ Cha.nonce12 $ BL.toStrict nonceBS
  state0 <- cf DecryptKeyException $ Cha.initialize key nonce
  let loop state1 = do
        ebs <- awaitExcept16 id
        case ebs of
          -- Upstream is exhausted: whatever is left must be the tag.
          Left final ->
            case Poly1305.authTag final of
              CE.CryptoPassed final' | Cha.finalize state1 == final' -> return ()
              _ -> throwM MismatchedAuth
          -- More ciphertext: decrypt it and keep going.
          Right bs -> do
            let (bs', state2) = Cha.decrypt bs state1
            yield bs'
            loop state2
  loop $ Cha.finalizeAAD state0
  where
    -- Return the next run of input while always holding back the last 16
    -- bytes seen so far, since they may turn out to be the tag.
    --
    -- @front@ prepends the bytes accumulated so far. The result is
    -- @Right ciphertext@ while more input may follow (the held-back 16 bytes
    -- are pushed back with 'leftover'), or @Left rest@ once upstream is
    -- exhausted, where @rest@ is everything that was still held back.
    awaitExcept16 front = do
      mbs <- await
      case mbs of
        Nothing -> return $ Left $ front B.empty
        Just bs -> do
          let bs' = front bs
          if B.length bs' > 16
            then do
              -- Everything except the final 16 bytes is definitely ciphertext.
              let (x, y) = B.splitAt (B.length bs' - 16) bs'
              assert (B.length y == 16) leftover y
              return $ Right x
            -- 16 bytes or fewer: cannot release anything yet, keep accumulating.
            else awaitExcept16 (B.append bs')