-- |
-- Module      : Crypto.Cipher.AES
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : stable
-- Portability : good
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.Cipher.AES
    ( AES128
    , AES192
    , AES256
    ) where

import Crypto.Error
import Crypto.Cipher.Types
import Crypto.Cipher.Utils
import Crypto.Cipher.Types.Block
import Crypto.Cipher.AES.Primitive
import Crypto.Internal.Imports

-- | AES with a 128 bit (16 byte) key.
--
-- A thin newtype over the shared 'AES' state, so that the key size is
-- carried in the type and selects the matching 'Cipher' instance.
newtype AES128 = AES128 AES
    -- NFData is reused from the wrapped 'AES' state via
    -- GeneralizedNewtypeDeriving (enabled at the top of the module).
    deriving (NFData)

-- | AES with a 192 bit (24 byte) key.
--
-- A thin newtype over the shared 'AES' state, so that the key size is
-- carried in the type and selects the matching 'Cipher' instance.
newtype AES192 = AES192 AES
    -- NFData is reused from the wrapped 'AES' state via
    -- GeneralizedNewtypeDeriving (enabled at the top of the module).
    deriving (NFData)

-- | AES with a 256 bit (32 byte) key.
--
-- A thin newtype over the shared 'AES' state, so that the key size is
-- carried in the type and selects the matching 'Cipher' instance.
newtype AES256 = AES256 AES
    -- NFData is reused from the wrapped 'AES' state via
    -- GeneralizedNewtypeDeriving (enabled at the top of the module).
    deriving (NFData)

-- Cipher instance for AES128: the key must be exactly 16 bytes.
instance Cipher AES128 where
    cipherName    _ = "AES128"
    cipherKeySize _ = KeySizeFixed 16
    -- The key is first checked by validateKeySize; only a key that passes
    -- reaches initAES, and the resulting AES state is wrapped in AES128.
    -- The undefined value is only a type witness selecting this instance;
    -- presumably validateKeySize only hands it to cipherKeySize, which
    -- ignores its argument (see the wildcard pattern above).
    cipherInit k    = AES128 <$> (initAES =<< validateKeySize (undefined :: AES128) k)

-- Cipher instance for AES192: the key must be exactly 24 bytes.
instance Cipher AES192 where
    cipherName    _ = "AES192"
    cipherKeySize _ = KeySizeFixed 24
    -- The key is first checked by validateKeySize; only a key that passes
    -- reaches initAES, and the resulting AES state is wrapped in AES192.
    -- The undefined value is only a type witness selecting this instance;
    -- presumably validateKeySize only hands it to cipherKeySize, which
    -- ignores its argument (see the wildcard pattern above).
    cipherInit k    = AES192 <$> (initAES =<< validateKeySize (undefined :: AES192) k)

-- Cipher instance for AES256: the key must be exactly 32 bytes.
instance Cipher AES256 where
    cipherName    _ = "AES256"
    cipherKeySize _ = KeySizeFixed 32
    -- The key is first checked by validateKeySize; only a key that passes
    -- reaches initAES, and the resulting AES state is wrapped in AES256.
    -- The undefined value is only a type witness selecting this instance;
    -- presumably validateKeySize only hands it to cipherKeySize, which
    -- ignores its argument (see the wildcard pattern above).
    cipherInit k    = AES256 <$> (initAES =<< validateKeySize (undefined :: AES256) k)


-- Generates the 'BlockCipher' and 'BlockCipher128' instances for one of the
-- AES newtypes above. The macro argument is used both as the type name and
-- as its data constructor: every method unwraps the newtype and delegates to
-- the matching operation on the underlying AES state (encryptECB,
-- decryptECB, encryptCBC, decryptCBC, encryptCTR, encryptXTS, decryptXTS),
-- presumably provided by Crypto.Cipher.AES.Primitive.
--
-- Notes on the macro body:
--  * CPP joins the continuation lines into a single physical line, which is
--    why the instance bodies use explicit braces and semicolons, and why no
--    Haskell line comment may be placed inside the macro.
--  * blockSize is fixed at 16 bytes for every AES key size.
--  * The IV is rebuilt as (IV iv) instead of being passed through unchanged,
--    presumably so it gets the IV type expected by the primitive; confirm
--    against the definition of IV.
--  * AEAD: GCM and OCB setup always yields CryptoPassed; CCM setup goes
--    through ccmInit, which returns a CryptoFailable and so may fail; any
--    other mode yields CryptoError_AEADModeNotSupported.
--  * XTS takes a pair of ciphers of the same AES variant and passes both
--    underlying states to the primitive.
#define INSTANCE_BLOCKCIPHER(CSTR) \
instance BlockCipher CSTR where \
    { blockSize _ = 16 \
    ; ecbEncrypt (CSTR aes) = encryptECB aes \
    ; ecbDecrypt (CSTR aes) = decryptECB aes \
    ; cbcEncrypt (CSTR aes) (IV iv) = encryptCBC aes (IV iv) \
    ; cbcDecrypt (CSTR aes) (IV iv) = decryptCBC aes (IV iv) \
    ; ctrCombine (CSTR aes) (IV iv) = encryptCTR aes (IV iv) \
    ; aeadInit AEAD_GCM (CSTR aes) iv = CryptoPassed $ AEAD (gcmMode aes) (gcmInit aes iv) \
    ; aeadInit AEAD_OCB (CSTR aes) iv = CryptoPassed $ AEAD (ocbMode aes) (ocbInit aes iv) \
    ; aeadInit (AEAD_CCM n m l) (CSTR aes) iv = AEAD (ccmMode aes) <$> ccmInit aes iv n m l \
    ; aeadInit _        _          _  = CryptoFailed CryptoError_AEADModeNotSupported \
    }; \
instance BlockCipher128 CSTR where \
    { xtsEncrypt (CSTR aes1, CSTR aes2) (IV iv) = encryptXTS (aes1,aes2) (IV iv) \
    ; xtsDecrypt (CSTR aes1, CSTR aes2) (IV iv) = decryptXTS (aes1,aes2) (IV iv) \
    };

-- Instantiate the block-cipher classes for each supported AES key size.
INSTANCE_BLOCKCIPHER(AES128)
INSTANCE_BLOCKCIPHER(AES192)
INSTANCE_BLOCKCIPHER(AES256)