-- |
-- Module      : Crypto.MAC.CMAC
-- License     : BSD-style
-- Maintainer  : Kei Hibino <ex8k.hibino@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- Provide the CMAC (Cipher based Message Authentication Code) base algorithm.
-- <http://en.wikipedia.org/wiki/CMAC>
-- <http://csrc.nist.gov/publications/nistpubs/800-38B/SP_800-38B.pdf>
--
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.MAC.CMAC
    ( cmac
    , CMAC
    , subKeys
    ) where

import           Data.Word
import           Data.Bits (setBit, testBit, shiftL)
import           Data.List (foldl')

import           Crypto.Cipher.Types
import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes)
import qualified Crypto.Internal.ByteArray as B

-- | Authentication code
--
-- A thin wrapper around the raw tag bytes.  The type parameter is a phantom:
-- it does not occur in the representation and only records which cipher the
-- tag was computed with (see the result type of 'cmac').
newtype CMAC a = CMAC Bytes
    deriving (ByteArrayAccess)

-- Equality is delegated to 'B.constEq' on the underlying bytes instead of
-- being derived.
-- NOTE(review): going by its name 'B.constEq' is a constant-time comparison,
-- which is what a tag check wants -- confirm against the ByteArray layer.
instance Eq (CMAC a) where
  CMAC b1 == CMAC b2 = B.constEq b1 b2

-- | compute a MAC using the supplied cipher
--
-- The message is split into cipher-block-sized chunks by 'cmacChunks' (which
-- also mixes the appropriate sub-key into the final chunk).  Starting from an
-- all-zero block, each chunk is XORed into the running value and the result
-- is encrypted with 'ecbEncrypt'; the value left after the last chunk is the
-- tag.
cmac :: (ByteArrayAccess bin, BlockCipher cipher)
     => cipher      -- ^ key to compute CMAC with
     -> bin         -- ^ input message
     -> CMAC cipher -- ^ output tag
cmac k msg = CMAC (foldl' step zeroV blocks)
  where
    -- one chaining step: fold the next chunk into the running value
    step acc blk = ecbEncrypt k (bxor acc blk)
    -- initial chaining value: one block of zero bytes
    zeroV        = B.replicate (blockSize k) 0 :: Bytes
    (k1, k2)     = subKeys k
    blocks       = cmacChunks k k1 k2 (B.convert msg)

-- Split a message into block-sized chunks ready for the chaining loop in
-- 'cmac'.
--
-- Every chunk except the last is returned untouched.  The last chunk is
-- treated specially:
--
--   * if it is exactly one block long it is XORed with the first sub-key;
--   * otherwise it is padded with a single 0x80 byte followed by zero bytes
--     up to the block size, and XORed with the second sub-key.
--
-- An empty message yields a single chunk consisting of padding only, XORed
-- with the second sub-key.
cmacChunks :: (BlockCipher k, ByteArray ba) => k -> ba -> ba -> ba -> [ba]
cmacChunks k k1 k2 = go
  where
    -- block size in bytes; it does not change between chunks
    bytes = blockSize k

    go msg
      | not (B.null rest) = block : go rest
      | lack == 0         = [bxor k1 block]
      | otherwise         =
          [bxor k2 (block `B.append` B.pack (0x80 : replicate (lack - 1) 0))]
      where
        (block, rest) = B.splitAt bytes msg
        -- number of bytes missing from a full block
        lack          = bytes - B.length block

-- | make sub-keys used in CMAC
--
-- An all-zero block is encrypted with the cipher; the first sub-key is
-- 'subKey' applied to that ciphertext and the second sub-key is 'subKey'
-- applied to the first.
subKeys :: (BlockCipher k, ByteArray ba)
        => k         -- ^ key to compute CMAC with
        -> (ba, ba)  -- ^ sub-keys to compute CMAC
subKeys k = (k1, k2)
  where
    -- reduction bit pattern matching the cipher's block size
    ipt = cipherIPT k
    -- encryption of one block of zero bytes
    k0  = ecbEncrypt k $ B.replicate (blockSize k) 0
    k1  = subKey ipt k0
    k2  = subKey ipt k1

-- polynomial multiply operation to calculate subkey
--
-- The input is shifted left by one bit ('shiftL1').  If the most significant
-- bit of its first byte was set before the shift, the result is additionally
-- XORed with the supplied bit pattern @ipt@.  Empty input yields empty output.
subKey :: (ByteArray ba) => [Word8] -> ba -> ba
subKey ipt ws = case B.unpack ws of
    []                 -> B.empty
    w : _
      | testBit w 7    -> B.pack ipt `bxor` shiftL1 ws
      | otherwise      -> shiftL1 ws

-- Shift a whole byte array left by one bit: unpack to a list of bytes, shift
-- with 'shiftL1W', and pack the result again.
shiftL1 :: (ByteArray ba) => ba -> ba
shiftL1 = B.pack . shiftL1W . B.unpack

-- Shift a list of bytes left by one bit, treating the list as one long bit
-- string whose most significant byte comes first.
--
-- Each output byte is the input byte shifted left by one, with its lowest bit
-- taken from the highest bit of the following byte.  The last byte has no
-- successor and receives a zero bit; the highest bit of the first byte is
-- dropped.  The output has the same length as the input.
shiftL1W :: [Word8] -> [Word8]
shiftL1W ws = zipWith shiftWithCarry ws (drop 1 ws ++ [0])
  where
    -- shift @cur@ left and pull in the top bit of the byte that follows it
    shiftWithCarry cur next
      | testBit next 7 = setBit shifted 0
      | otherwise      = shifted
      where
        shifted = shiftL cur 1

-- 'B.xor' specialised so that both operands and the result share one
-- byte-array type.
bxor :: ByteArray ba => ba -> ba -> ba
bxor = B.xor


-----


-- Bit pattern used by 'subKey' for the given cipher, obtained by passing the
-- cipher's block size (in bytes) to 'expandIPT'.
cipherIPT :: BlockCipher k => k -> [Word8]
cipherIPT = expandIPT . blockSize

-- Data type which represents the smallest irreducible binary polynomial
-- of a specified degree.
--
-- The maximum-degree term and the degree-0 term are omitted; only the three
-- intermediate exponents are stored.
-- For example, the value /Q 7 2 1/ corresponds to the degree /128/.
-- It represents that the smallest irreducible binary polynomial of degree 128
-- is x^128 + x^7 + x^2 + x^1 + 1.
data IPolynomial
  = Q Int Int Int
---  | T Int

-- Look up the polynomial definition for a block width given in bits.
-- Only widths 64 and 128 are defined; any other width gives 'Nothing'.
iPolynomial :: Int -> Maybe IPolynomial
iPolynomial 64  = Just (Q 4 3 1)
iPolynomial 128 = Just (Q 7 2 1)
iPolynomial _   = Nothing

-- Expand a tail bit pattern of irreducible binary polynomial
--
-- Takes the block width in bytes, looks up the polynomial for the
-- corresponding bit width with 'iPolynomial' and expands it with
-- 'expandIPT''.  Calls 'error' when no polynomial is defined for that width.
expandIPT :: Int -> [Word8]
expandIPT bytes =
    case iPolynomial nb of
      Just ipt -> expandIPT' bytes ipt
      Nothing  -> error $ "Irreducible binary polynomial not defined against " ++ show nb ++ " bit"
  where
    -- block width in bits
    nb = bytes * 8

-- Expand a tail bit pattern of irreducible binary polynomial
--
-- Starts from @bytes@ zero bytes, sets bit 0 and the three bits named by the
-- polynomial, then reverses the list.  Bits are numbered from the head of the
-- list before the reversal (bit @i@ lives in byte @i `quot` 8@ at position
-- @i `rem` 8@), so after reversing the low-numbered bits sit in the last byte.
expandIPT' :: Int         -- ^ width in byte
           -> IPolynomial -- ^ irreducible binary polynomial definition
           -> [Word8]     -- ^ result bit pattern
expandIPT' bytes (Q x y z) =
    reverse . setB x . setB y . setB z . setB 0 $ replicate bytes 0
  where
    -- set bit number i in a list of bytes; a pattern match replaces the
    -- partial head/tail so an out-of-range index fails with a clear message
    setB i ws = case splitAt q ws of
        (pre, b : post) -> pre ++ setBit b r : post
        (_, [])         -> error "Crypto.MAC.CMAC.expandIPT': bit index exceeds pattern width"
      where
        (q, r) = i `quotRem` 8