{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Tahoe.CHK.Cipher (
    Key (keyBytes, keyCipher),
) where

import Control.DeepSeq (NFData)
import Crypto.Cipher.Types (AEAD, BlockCipher (..), Cipher (..))
import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.Coerce (coerce)
import GHC.Generics (Generic)

{- | A block cipher key that also remembers the bytes it was made from.

 A "Crypto.Cipher.Types" 'Cipher' value offers no way to get back the key
 material it was initialised with.  'Key' works around that by storing a
 'ScrubbedBytes' copy of the original bytes next to the initialised cipher.
 A 'Key' can therefore be built from bytes with 'cipherInit' and turned back
 into bytes through its 'BA.ByteArrayAccess' instance.
-}
data Key cipher = Key
    { keyBytes :: ScrubbedBytes
    -- ^ The raw key material, retained so it can be recovered later.
    , keyCipher :: cipher
    -- ^ The initialised cipher that performs the actual block operations.
    }

deriving instance Generic (Key cipher)
deriving instance NFData cipher => NFData (Key cipher)

instance forall cipher. Cipher cipher => Cipher (Key cipher) where
    -- Initialise the wrapped cipher from the given bytes and keep a scrubbed
    -- copy of those same bytes beside it.  A failure reported by the wrapped
    -- 'cipherInit' is passed through unchanged.
    cipherInit bytes =
        fmap (Key (BA.convert bytes)) (cipherInit bytes)

    -- The answer depends only on the type, so the argument is never forced.
    -- That keeps these methods usable with an 'undefined' proxy value.
    cipherName _ = cipherName (undefined :: cipher)

    cipherKeySize _ = cipherKeySize (undefined :: cipher)

instance forall cipher. BlockCipher cipher => BlockCipher (Key cipher) where
    -- The argument is deliberately left uninspected.  NOTE(review):
    -- cryptonite helpers such as 'makeIV' / 'nullIV' appear to call
    -- 'blockSize' with an undefined proxy, so forcing it here would crash
    -- them.  Confirm before changing this to use 'keyCipher'.
    blockSize _ = blockSize (undefined :: cipher)

    -- ECB needs no IV; simply delegate to the wrapped cipher.
    ecbEncrypt key = ecbEncrypt (keyCipher key)
    ecbDecrypt key = ecbDecrypt (keyCipher key)

    -- 'IV' carries its cipher only as a type index.  'coerce' re-tags an
    -- @IV (Key cipher)@ as an @IV cipher@ without copying anything.
    cbcEncrypt Key{keyCipher = c} iv = cbcEncrypt c (coerce iv)
    cbcDecrypt Key{keyCipher = c} iv = cbcDecrypt c (coerce iv)

    cfbEncrypt Key{keyCipher = c} iv = cfbEncrypt c (coerce iv)
    cfbDecrypt Key{keyCipher = c} iv = cfbDecrypt c (coerce iv)
    ctrCombine Key{keyCipher = c} iv = ctrCombine c (coerce iv)

    -- Build the AEAD context on the wrapped cipher.  On success, re-tag the
    -- result so its type refers to the 'Key' wrapper instead.
    aeadInit mode Key{keyCipher = c} nonce =
        fmap retag (aeadInit mode c nonce)
      where
        retag :: AEAD cipher -> AEAD (Key cipher)
        retag = coerce

-- | Byte access to a 'Key' reads the retained raw key material ('keyBytes').
-- The initialised cipher is not consulted.
instance BA.ByteArrayAccess (Key cipher) where
    length Key{keyBytes = raw} = BA.length raw
    withByteArray Key{keyBytes = raw} = BA.withByteArray raw