{-# LANGUAGE OverloadedStrings #-}
module Dhall.Secret.Age
  ( encrypt,
    decrypt,
    generateX25519Identity,
    parseRecipient,
    parseIdentity,
    toRecipient,
  )
where
import qualified Codec.Binary.Bech32          as Bech32
import qualified Crypto.Cipher.ChaChaPoly1305 as CC
import           Crypto.Error                 (CryptoError (..),
                                               CryptoFailable (..),
                                               eitherCryptoError,
                                               throwCryptoErrorIO)
import           Crypto.Hash                  (SHA256)
import           Crypto.KDF.HKDF              (PRK)
import qualified Crypto.KDF.HKDF              as HKDF
import           Crypto.MAC.HMAC              (HMAC, hmac)
import qualified Crypto.PubKey.Curve25519     as X25519
import           Crypto.Random                (MonadRandom (getRandomBytes))
import           Data.ByteArray               (ByteArrayAccess, constEq,
                                               convert)
import           Data.ByteString              (ByteString, intercalate)
import qualified Data.ByteString              as BS
import qualified Data.ByteString.Base64       as B64
import qualified Data.ByteString.Char8        as BC
import           Data.Either                  (isRight)
import           Data.List                    (find)
import           Data.Maybe                   (fromMaybe)
import           Data.PEM                     (PEM (..), pemParseBS, pemWriteBS)
import           Data.Text                    (Text)
import qualified Data.Text                    as T

-- | A single recipient stanza of an age header: a @-> type args@ line
-- followed by a base64 body. The body is stored here already decoded;
-- 'marshalStanza' re-encodes it when the header is written.
data Stanza = Stanza
  { stzType :: ByteString
    -- ^ Stanza type tag, e.g. @X25519@.
  , stzArgs :: [ByteString]
    -- ^ Arguments after the type. When produced by 'parseHeader' the whole
    -- remainder of the stanza line is kept as a single element.
  , stzBody :: ByteString
    -- ^ Raw (base64-decoded) stanza body.
  }
  deriving (Show)

-- | The public half of an X25519 key pair; files are encrypted to it.
data X25519Recipient = X25519Recipient X25519.PublicKey

-- | Renders the recipient as Bech32 with the @age@ human-readable part,
-- i.e. the familiar @age1...@ string.
instance Show X25519Recipient where
  show (X25519Recipient publicKey) = T.unpack (b32 "age" publicKey)

-- | A full X25519 key pair (public key, secret key) able to decrypt files.
data X25519Identity = X25519Identity X25519.PublicKey X25519.SecretKey

-- | Renders the /secret/ key as upper-cased Bech32 with the
-- @AGE-SECRET-KEY-@ prefix. Because this exposes key material, do not
-- 'show' identities into logs.
instance Show X25519Identity where
  show (X25519Identity _publicKey secretKey) =
    let encoded = b32 "AGE-SECRET-KEY-" secretKey
     in T.unpack (T.toUpper encoded)

-- | A parsed age header: its recipient stanzas and the header MAC bytes
-- (base64-decoded from the @---@ line).
data Header = Header [Stanza] ByteString

-- | A parsed ciphertext: the header, the payload nonce (the 16 bytes that
-- 'parseCipher' splits off after the header) and the encrypted payload.
data CipherBlock = Cipher Header ByteString ByteString

-- | Encrypt a message to one or more X25519 recipients, producing an
-- ASCII-armored (PEM, @AGE ENCRYPTED FILE@) age file.
--
-- A random 16-byte file key is wrapped once per recipient ('mkStanza'). The
-- payload is encrypted in chunks under a key derived by 'payloadKey' from a
-- random 16-byte nonce and the file key.
--
-- Calls 'error' when the recipient list is empty. Such a file would carry
-- no stanzas and nobody could ever decrypt it.
encrypt :: [X25519Recipient] -> ByteString -> IO ByteString
encrypt [] _ = error "encrypt: at least one recipient is required"
encrypt recipients msg = do
  fileKey <- getRandomBytes 16 :: IO ByteString
  nonce <- getRandomBytes 16 :: IO ByteString
  stanzas <- traverse (mkStanza fileKey) recipients
  -- NOTE(review): zeroNonceOf 11 is presumably the all-zero 11-byte STREAM
  -- counter; confirm against its definition.
  body <- encryptChunks BS.empty (payloadKey nonce fileKey) (zeroNonceOf 11) msg
  pure $
    pemWriteBS
      PEM
        { pemName = "AGE ENCRYPTED FILE",
          pemHeader = [],
          pemContent = mkHeader fileKey stanzas <> nonce <> body
        }

-- | Decrypt an armored age file with the first identity able to unwrap a
-- file key from one of its stanzas.
--
-- Calls 'error' when:
--
--   * the ciphertext cannot be parsed ('parseCipher');
--   * no identity unwraps a file key;
--   * the header MAC does not verify.
--
-- Chunk authentication failures surface from 'decryptChunks'.
decrypt :: ByteString -> [X25519Identity] -> IO ByteString
decrypt ciphertext identities = do
  Cipher header nonce body <- either error pure $ parseCipher ciphertext
  let Header stanzas mac = header
  case find isRight (findFileKey identities header) of
    Just (Right fileKey) -> do
      let (_headerNoMac, expectedMac) = mkHeaderMac fileKey stanzas
      -- Compare in constant time, and never echo the expected MAC back:
      -- doing so would hand an attacker a valid MAC for a tampered header.
      if constEq expectedMac mac
        then decryptChunks BS.empty (payloadKey nonce fileKey) (zeroNonceOf 11) body
        else error "Header MAC not match"
    _ -> error "No file key found"

-- | Create a fresh random X25519 identity. The public key is derived from
-- the newly generated secret key.
generateX25519Identity :: IO X25519Identity
generateX25519Identity =
  (\secretKey -> X25519Identity (X25519.toPublic secretKey) secretKey)
    <$> X25519.generateSecretKey

-- | Parse a Bech32-encoded recipient (e.g. @age1...@).
--
-- 'b32dec' discards the human-readable part, so the prefix is not checked
-- here. Bech32 failures raise via 'error'. A payload that is not a valid
-- X25519 public key is thrown as a 'CryptoError'.
parseRecipient :: Text -> IO X25519Recipient
parseRecipient recipientText = do
  publicKey <- throwCryptoErrorIO (X25519.publicKey (b32dec recipientText))
  pure (X25519Recipient publicKey)

-- | Parse a Bech32-encoded secret key (e.g. @AGE-SECRET-KEY-1...@) into an
-- identity, deriving its public key.
--
-- As with 'parseRecipient', the human-readable part is not checked. An
-- invalid key size is thrown as a 'CryptoError'.
parseIdentity :: Text -> IO X25519Identity
parseIdentity identityText =
  throwCryptoErrorIO (toIdentity <$> X25519.secretKey (b32dec identityText))
  where
    toIdentity secretKey = X25519Identity (X25519.toPublic secretKey) secretKey

-- | Decrypt a chunked payload and prepend @acc@ to the plaintext.
--
-- Each ciphertext chunk is 64 KiB of data plus a 16-byte tag. The final
-- chunk, i.e. the one after which nothing remains, is opened with last-chunk
-- flag @0x01@; all others use @0x00@. The counter is advanced with
-- 'incNonce' between chunks.
--
-- Decrypted chunks are collected in a list and concatenated once at the end,
-- so the total cost stays linear in the payload size. The previous repeated
-- @acc <> chunk@ re-copied the accumulated output for every chunk.
decryptChunks :: ByteString -> ByteString -> ByteString -> ByteString -> IO ByteString
decryptChunks acc key nonce body = go [acc] nonce body
  where
    chunkSize = 64 * 1024 + 16
    -- done holds already-decrypted pieces in reverse order.
    go done counter rest = case BS.splitAt chunkSize rest of
      (chunk, remaining)
        | BS.null remaining -> do
            plain <- decryptChunk key counter chunk (BS.pack [1])
            pure $ BS.concat (reverse (plain : done))
        | otherwise -> do
            plain <- decryptChunk key counter chunk (BS.pack [0])
            go (plain : done) (incNonce counter) remaining

-- | Encrypt a message in 64 KiB chunks and prepend @acc@ to the
-- ciphertext.
--
-- The last chunk (the one after which nothing remains, possibly empty or
-- full-size) is sealed with last-chunk flag @0x01@, all others with @0x00@.
-- The counter is advanced with 'incNonce' between chunks.
--
-- Sealed chunks are collected in a list and concatenated once, so the cost
-- stays linear. Repeated @acc <> chunk@ was quadratic in the message size.
encryptChunks :: ByteString -> ByteString -> ByteString -> ByteString -> IO ByteString
encryptChunks acc key nonce msg = go [acc] nonce msg
  where
    chunkSize = 64 * 1024
    -- done holds already-sealed pieces in reverse order.
    go done counter rest = case BS.splitAt chunkSize rest of
      (chunk, remaining)
        | BS.null remaining -> do
            sealed <- encryptChunk key counter chunk (BS.pack [1])
            pure $ BS.concat (reverse (sealed : done))
        | otherwise -> do
            sealed <- encryptChunk key counter chunk (BS.pack [0])
            go (sealed : done) (incNonce counter) remaining

-- | Seal one payload chunk with ChaCha20-Poly1305 and no associated data.
--
-- The 12-byte AEAD nonce is @nonce <> isFinal@: the running counter followed
-- by the last-chunk flag byte. Returns the ciphertext followed by its
-- Poly1305 tag. Invalid key or nonce sizes are thrown as 'CryptoError'.
encryptChunk :: ByteString -> ByteString -> ByteString -> ByteString -> IO ByteString
encryptChunk key nonce msg isFinal = do
  initial <- throwCryptoErrorIO (startState =<< CC.nonce12 (nonce <> isFinal))
  let (ciphertext, afterData) = CC.encrypt msg initial
      tag = convert (CC.finalize afterData) :: ByteString
  pure (ciphertext <> tag)
  where
    startState aeadNonce = CC.finalizeAAD <$> CC.initialize key aeadNonce

-- | Open one payload chunk sealed by 'encryptChunk'.
--
-- The last 16 bytes of @cipherblob@ are the Poly1305 tag; the rest is
-- ciphertext. The AEAD nonce is @nonce <> isFinal@. Calls 'error'
-- ("Invalid auth tag") when the tag does not verify.
decryptChunk :: ByteString -> ByteString -> ByteString -> ByteString -> IO ByteString
decryptChunk key nonce cipherblob isFinal = do
  st1 <- throwCryptoErrorIO $ do
    payloadNonce <- CC.nonce12 (nonce <> isFinal)
    CC.finalizeAAD <$> CC.initialize key payloadNonce
  let (msg, tag) = BS.splitAt (BS.length cipherblob - 16) cipherblob
      (plaintext, st2) = CC.decrypt msg st1
      computedTag = convert (CC.finalize st2) :: ByteString
  -- Constant-time comparison, so verification time does not reveal how
  -- many leading tag bytes matched.
  if constEq computedTag tag
    then pure plaintext
    else error "Invalid auth tag"

-- | Split an armored age file into header, payload nonce and payload.
--
-- Uses the first PEM block in the input. The header MAC is later recomputed
-- from a fixed @age-encryption.org/v1@ intro line (see 'mkHeaderMac'), so
-- the received intro line is checked here explicitly. Otherwise it would
-- not be authenticated at all.
--
-- Returns 'Left' for:
--
--   * input without a PEM block;
--   * a wrong version line;
--   * a malformed header;
--   * a payload too short to hold the 16-byte nonce.
parseCipher :: ByteString -> Either String CipherBlock
parseCipher ct = do
  pems <- pemParseBS ct
  content <- case pems of
    pem : _ -> Right (pemContent pem)
    [] -> Left "no PEM block found"
  let (intro, rest) = BS.break (== 0x0a) content
  if intro == "age-encryption.org/v1"
    then Right ()
    else Left "unsupported age version line"
  (header, rest2) <- parseHeader (Header [] "") (BS.drop 1 rest)
  let (nonce, body) = BS.splitAt 16 rest2
  if BS.length nonce == 16
    then Right ()
    else Left "truncated payload nonce"
  pure $ Cipher header nonce body

-- | Parse the stanza section of an age header, up to and including the
-- @--- <mac>@ line.
--
-- The 'Header' argument is an accumulator. Parsed stanzas are consed onto
-- it and reversed once the @---@ line is reached; the MAC is then replaced by
-- the base64-decoded value from that line. Returns the header and the bytes
-- following the MAC line. Returns 'Left' on any line that starts with
-- neither @-> @ nor @---@.
--
-- A stanza body is a sequence of 64-column base64 lines, ended by the first
-- shorter (possibly empty) line. Bodies spanning several lines are therefore
-- accepted; previously only the first line was read.
parseHeader :: Header -> ByteString -> Either String (Header, ByteString)
parseHeader (Header stz mac) content =
  case BS.take 3 content of
    "---" ->
      let (macLine, payload) = BS.break isLF content
       in Right
            ( Header (reverse stz) (B64.decodeLenient $ BS.drop 4 macLine),
              BS.drop 1 payload
            )
    "-> " ->
      let (argLine, rest1) = BS.break isLF $ BS.drop 3 content
          (bodyB64, rest2) = takeBody (BS.drop 1 rest1)
          (stztype, rest11) = BS.break isSpace argLine
          st =
            Stanza
              { stzType = stztype,
                -- The remainder of the line is kept as a single argument, so
                -- marshalStanza's `intercalate " "` reproduces the original
                -- spacing when the header MAC is recomputed.
                stzArgs = [BS.drop 1 rest11],
                stzBody = B64.decodeLenient bodyB64
              }
       in parseHeader (Header (st : stz) mac) rest2
    _ -> Left "invalid headers"
  where
    isLF = (== 0x0a)
    isSpace = (== 0x20)
    -- Collect full 64-column body lines until the first shorter line; returns
    -- the joined base64 text and the input after that line's newline.
    takeBody bs =
      let (line, rest) = BS.break isLF bs
          rest' = BS.drop 1 rest
       in if BS.length line == 64
            then let (more, rest'') = takeBody rest' in (line <> more, rest'')
            else (line, rest')

-- | Try every identity against every stanza (identity-major order) and
-- return, for each pair, either the unwrapped file key or the reason it
-- failed.
--
-- Each attempt performs the X25519 unwrap:
--
--   1. Decode the ephemeral share from the first stanza argument.
--   2. Compute the shared secret, rejecting an all-zero result (a low-order
--      share).
--   3. Derive the wrap key with HKDF, using salt @share || our public key@.
--   4. Open the body with ChaCha20-Poly1305 under an all-zero nonce.
--
-- The tag is compared in constant time.
findFileKey :: [X25519Identity] -> Header -> [Either CryptoError ByteString]
findFileKey identities (Header stanzas _mac) = hasKey <$> identities <*> stanzas
  where
    hasKey :: X25519Identity -> Stanza -> Either CryptoError ByteString
    hasKey (X25519Identity pk sec) stz = eitherCryptoError $ do
      theirPkB64 <- case stzArgs stz of
        arg : _ -> pure arg
        [] -> CryptoFailed CryptoError_PublicKeySizeInvalid
      theirPk <- X25519.publicKey (B64.decodeLenient theirPkB64)
      let shared = convert (X25519.dh theirPk sec) :: ByteString
      -- The age spec requires rejecting an all-zero X25519 output.
      if BS.all (== 0) shared
        then CryptoFailed CryptoError_ScalarMultiplicationInvalid
        else pure ()
      let salt = convert theirPk <> convert pk :: ByteString
          wrappingKey = hkdf "age-encryption.org/v1/X25519" shared salt
      nonce <- CC.nonce12 (zeroNonceOf 12)
      st0 <- CC.initialize wrappingKey nonce
      let wrapped = stzBody stz
          (sealed, tag) = BS.splitAt (BS.length wrapped - 16) wrapped
          (fileKey, st1) = CC.decrypt sealed st0
          computedTag = convert (CC.finalize st1) :: ByteString
      if constEq computedTag tag
        then pure fileKey
        else CryptoFailed CryptoError_AuthenticationTagSizeInvalid

-- | The recipient (public key) corresponding to an identity.
toRecipient :: X25519Identity -> X25519Recipient
toRecipient identity = case identity of
  X25519Identity publicKey _secretKey -> X25519Recipient publicKey

-- | Bech32-encode bytes under the given human-readable part.
--
-- NOTE(review): on failure this returns the 'show'n error text instead of
-- signalling an error, so callers cannot tell a failure from an encoding.
b32 :: (ByteArrayAccess b) => Text -> b -> Text
b32 header b =
  either (T.pack . show) encodeUnder (Bech32.humanReadablePartFromText header)
  where
    encodeUnder hrp =
      either (T.pack . show) id $
        Bech32.encode hrp (Bech32.dataPartFromBytes (convert b))

-- | Decode a Bech32 string to its data bytes, ignoring the human-readable
-- part. Calls 'error' if the string is not valid Bech32 or its data part
-- cannot be converted to bytes.
b32dec :: Text -> ByteString
b32dec r =
  either
    (const (error "Cannot decode bech32"))
    (fromMaybe (error "Cannot extract bech32 data") . Bech32.dataPartToBytes . snd)
    (Bech32.decode r)

-- | Wrap the file key for one X25519 recipient.
--
-- Steps:
--
--   1. Generate an ephemeral key pair and compute the shared secret with
--      the recipient.
--   2. Derive the wrap key via HKDF (info @age-encryption.org/v1/X25519@,
--      salt @ephemeral public || recipient public@).
--   3. Seal the file key with ChaCha20-Poly1305 under an all-zero 12-byte
--      nonce.
--
-- The stanza argument is the unpadded-base64 ephemeral public key.
--
-- Throws a 'CryptoError' if the shared secret is all zeros. That happens
-- with a low-order recipient key, and the wrap key would then be computable
-- from public values alone.
mkStanza :: ByteString -> X25519Recipient -> IO Stanza
mkStanza fileKey (X25519Recipient theirPK) = do
  ephemeral <- X25519.generateSecretKey
  let ourPK = X25519.toPublic ephemeral
      shared = convert (X25519.dh theirPK ephemeral) :: ByteString
  if BS.all (== 0) shared
    then throwCryptoErrorIO (CryptoFailed CryptoError_ScalarMultiplicationInvalid)
    else pure ()
  let salt = convert ourPK <> convert theirPK :: ByteString
      wrappingKey = hkdf "age-encryption.org/v1/X25519" shared salt
  body <- throwCryptoErrorIO $ do
    nonce <- CC.nonce12 (BS.replicate 12 0)
    st0 <- CC.initialize wrappingKey nonce
    let (sealed, st1) = CC.encrypt fileKey st0
    pure $ sealed <> convert (CC.finalize st1)
  pure
    Stanza
      { stzType = "X25519",
        stzArgs = [encodeBase64Unpadded (convert ourPK)],
        stzBody = body
      }

-- | Serialize a stanza as it appears in the header:
--
--   * an @-> type args@ line (arguments joined by single spaces);
--   * the unpadded-base64 body passed through 'wrap64b';
--   * a trailing newline.
marshalStanza :: Stanza -> ByteString
marshalStanza stanza = BS.concat [argLine, wrap64b encodedBody, "\n"]
  where
    argLine =
      BS.concat ["-> ", stzType stanza, " ", intercalate " " (stzArgs stanza), "\n"]
    encodedBody = encodeBase64Unpadded (stzBody stanza)

-- | Build the complete textual header. This is the MAC'd part from
-- 'mkHeaderMac', then a space, the unpadded-base64 MAC, and a newline.
mkHeader :: ByteString -> [Stanza] -> ByteString
mkHeader fileKey recipients =
  BS.concat [headerNoMac, " ", encodeBase64Unpadded mac, "\n"]
  where
    (headerNoMac, mac) = mkHeaderMac fileKey recipients

-- | Build the MAC-covered part of an age v1 header together with its MAC.
--
-- The covered part consists of:
--
-- * the version line @age-encryption.org/v1@ plus newline;
-- * every recipient stanza, serialised with 'marshalStanza';
-- * the @---@ footer, with no trailing space or newline.
--
-- The MAC is HMAC-SHA-256 over that text. Its key is
-- @hkdf "header" fileKey ""@.
--
-- Returns @(coveredHeader, rawMacBytes)@.
mkHeaderMac :: ByteString -> [Stanza] -> (ByteString, ByteString)
mkHeaderMac fileKey recipients = (covered, convert tag)
  where
    covered :: ByteString
    covered =
      mconcat
        [ "age-encryption.org/v1\n",
          foldMap marshalStanza recipients,
          "---"
        ]

    tag :: HMAC SHA256
    tag = hmac (hkdf "header" fileKey "") covered

-- | HKDF-SHA-256 with a fixed 32-byte output.
--
-- The arguments are, in order: the @info@ context string, the input keying
-- material, and the salt. 'HKDF.extract' combines the salt and key into a
-- PRK, which is then expanded with @info@ to 32 bytes.
hkdf :: ByteString -> ByteString -> ByteString -> ByteString
hkdf info key salt = HKDF.expand prk info outputLength
  where
    prk :: PRK SHA256
    prk = HKDF.extract salt key

    outputLength :: Int
    outputLength = 32

-- | Increment a nonce, treating it as a big-endian unsigned integer.
--
-- The trailing run of @0xff@ bytes rolls over to @0x00@, and the byte just
-- before that run is incremented. If every byte is @0xff@, or the input is
-- empty, the result is all zeros of the same length: the counter wraps
-- silently.
incNonce :: ByteString -> ByteString
incNonce n =
  case BS.unsnoc front of
    Nothing -> zeros
    Just (higher, lastByte) -> BS.snoc higher (lastByte + 1) <> zeros
  where
    -- Everything after 'front' is 0xff and absorbs the carry.
    (front, carryRun) = BS.spanEnd (== 0xff) n
    zeros = BS.replicate (BS.length carryRun) 0

-- | A nonce consisting of @n@ zero bytes.
--
-- Returns the empty string when @n <= 0@. 'BS.replicate' builds the result
-- directly rather than going through an intermediate list.
zeroNonceOf :: Int -> ByteString
zeroNonceOf n = BS.replicate n 0

-- | Break the input into 64-byte lines separated by newlines.
--
-- The final line, which may be exactly 64 bytes long, gets no trailing
-- newline. An empty input yields an empty result.
wrap64b :: ByteString -> ByteString
wrap64b = intercalate "\n" . segments
  where
    segments :: ByteString -> [ByteString]
    segments s = case BS.splitAt 64 s of
      (line, rest)
        | BS.null rest -> [line]
        | otherwise -> line : segments rest

-- | Derive the 32-byte payload key from the payload nonce and the file key.
--
-- This is HKDF-SHA-256 with the nonce as salt, the file key as input keying
-- material, and @"payload"@ as the info string. That is exactly 'hkdf' with
-- the arguments reordered, so it delegates to 'hkdf' instead of duplicating
-- the extract/expand pipeline.
payloadKey :: ByteString -> ByteString -> ByteString
payloadKey nonce filekey = hkdf "payload" filekey nonce

-- | Base64-encode with 'B64.encode' and strip the trailing @=@ padding.
--
-- age header fields use this unpadded form.
encodeBase64Unpadded :: ByteString -> ByteString
encodeBase64Unpadded raw = unpadded
  where
    -- '=' only ever appears as trailing padding in base64 output.
    (unpadded, _padding) = BC.spanEnd (== '=') (B64.encode raw)