{-# LANGUAGE DeriveDataTypeable #-}
module Crypto.Cipher.ChaChaPoly1305.Conduit
( encrypt
, decrypt
, ChaChaException (..)
) where
import Control.Exception (assert)
import Control.Monad.Catch (Exception, MonadThrow, throwM)
import qualified Crypto.Cipher.ChaChaPoly1305 as Cha
import qualified Crypto.Error as CE
import qualified Crypto.MAC.Poly1305 as Poly1305
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Conduit (ConduitM, await, leftover, yield)
import qualified Data.Conduit.Binary as CB
import Data.Typeable (Typeable)
-- | Lift a 'CE.CryptoFailable' result into any 'MonadThrow'.
--
-- A 'CE.CryptoPassed' value is returned as-is. A 'CE.CryptoFailed' error is
-- wrapped with the supplied constructor and thrown with 'throwM', so callers
-- can tell which step (nonce or key setup, encrypt or decrypt) failed.
cf :: MonadThrow m
   => (CE.CryptoError -> ChaChaException) -- ^ wraps the underlying crypto error
   -> CE.CryptoFailable a                 -- ^ result to unwrap
   -> m a
cf _ (CE.CryptoPassed x) = return x
cf f (CE.CryptoFailed e) = throwM (f e)
-- | Failures thrown by 'encrypt' and 'decrypt'.
data ChaChaException
  = EncryptNonceException !CE.CryptoError
    -- ^ 'encrypt': 'Cha.nonce12' rejected the supplied nonce.
  | EncryptKeyException !CE.CryptoError
    -- ^ 'encrypt': 'Cha.initialize' rejected the supplied key.
  | DecryptNonceException !CE.CryptoError
    -- ^ 'decrypt': 'Cha.nonce12' rejected the nonce read from the stream.
  | DecryptKeyException !CE.CryptoError
    -- ^ 'decrypt': 'Cha.initialize' rejected the supplied key.
  | MismatchedAuth
    -- ^ 'decrypt': the trailing bytes of the stream are not a valid
    -- authentication tag, or do not equal the tag computed while decrypting.
  deriving (Show, Typeable)

instance Exception ChaChaException
-- | Encrypt a stream of 'ByteString' chunks with ChaCha20-Poly1305.
--
-- The output stream is laid out as:
--
-- 1. the nonce, exactly as supplied;
-- 2. one ciphertext chunk per input chunk, in order;
-- 3. the authentication tag, emitted once upstream is exhausted.
--
-- No additional authenticated data is used: the AAD phase is closed with
-- 'Cha.finalizeAAD' before the first chunk is processed.
--
-- Throws 'EncryptNonceException' if 'Cha.nonce12' rejects the nonce and
-- 'EncryptKeyException' if 'Cha.initialize' rejects the key. Both checks run
-- before anything is yielded downstream.
encrypt
  :: MonadThrow m
  => ByteString -- ^ nonce; must be accepted by 'Cha.nonce12'
  -> ByteString -- ^ symmetric key; must be accepted by 'Cha.initialize'
  -> ConduitM ByteString ByteString m ()
encrypt nonceBS key = do
  nonce <- cf EncryptNonceException $ Cha.nonce12 nonceBS
  state0 <- cf EncryptKeyException $ Cha.initialize key nonce
  -- The nonce travels in the clear at the head of the stream so that
  -- 'decrypt' can recover it.
  yield nonceBS
  let loop state1 = do
        mbs <- await
        case mbs of
          -- End of input: emit the authentication tag as the final chunk.
          Nothing -> yield $ BA.convert $ Cha.finalize state1
          Just bs -> do
            let (bs', state2) = Cha.encrypt bs state1
            yield bs'
            loop state2
  loop $ Cha.finalizeAAD state0
-- | Decrypt a stream produced by 'encrypt'.
--
-- The first 12 bytes of the stream are read as the nonce, the last 16 bytes
-- are treated as the authentication tag, and everything in between is
-- decrypted and yielded downstream.
--
-- __Note:__ plaintext chunks are yielded /before/ the tag has been checked;
-- verification only happens once upstream is exhausted. Downstream consumers
-- must not act on the plaintext until the conduit has completed without
-- throwing 'MismatchedAuth'.
--
-- Throws 'DecryptNonceException' if 'Cha.nonce12' rejects the bytes read as
-- the nonce (presumably including a stream shorter than 12 bytes — confirm
-- against 'CB.take'), 'DecryptKeyException' if 'Cha.initialize' rejects the
-- key, and 'MismatchedAuth' if the trailing bytes are not a valid tag or do
-- not equal the tag computed over the decrypted data.
decrypt
  :: MonadThrow m
  => ByteString -- ^ symmetric key; must be accepted by 'Cha.initialize'
  -> ConduitM ByteString ByteString m ()
decrypt key = do
  -- 'CB.take' returns a lazy ByteString, hence the 'BL.toStrict' below.
  nonceBS <- CB.take 12
  nonce <- cf DecryptNonceException $ Cha.nonce12 $ BL.toStrict nonceBS
  state0 <- cf DecryptKeyException $ Cha.initialize key nonce
  let loop state1 = do
        ebs <- awaitExcept16 id
        case ebs of
          -- Upstream is exhausted: whatever was held back is the tag.
          Left final ->
            case Poly1305.authTag final of
              CE.CryptoPassed final'
                | Cha.finalize state1 == final' -> return ()
              -- Either the held-back bytes are not a well-formed tag, or the
              -- tag does not match.
              _ -> throwM MismatchedAuth
          Right bs -> do
            let (bs', state2) = Cha.decrypt bs state1
            yield bs'
            loop state2
  loop $ Cha.finalizeAAD state0
  where
    -- Await input while always holding back the most recent 16 bytes, since
    -- they may turn out to be the trailing authentication tag.
    --
    -- @front@ prepends the bytes accumulated so far. Returns @Right chunk@
    -- with data that is certainly not part of the tag (the held-back 16
    -- bytes are pushed back with 'leftover'), or @Left rest@ with everything
    -- that remained when upstream ended.
    awaitExcept16
      :: Monad m
      => (ByteString -> ByteString)
      -> ConduitM ByteString o m (Either ByteString ByteString)
    awaitExcept16 front = do
      mbs <- await
      case mbs of
        Nothing -> return $ Left $ front B.empty
        Just bs -> do
          let bs' = front bs
          if B.length bs' > 16
            then do
              let (x, y) = B.splitAt (B.length bs' - 16) bs'
              assert (B.length y == 16) leftover y
              return $ Right x
            -- Not enough data yet to release anything: keep accumulating.
            else awaitExcept16 (B.append bs')