module Network.TLS.MAC (
    macSSL,
    hmac,
    prf_MD5,
    prf_SHA1,
    prf_SHA256,
    prf_TLS,
    prf_MD5SHA1,
) where

import qualified Data.ByteArray as B (xor)
import qualified Data.ByteString as B
import Network.TLS.Crypto
import Network.TLS.Imports
import Network.TLS.Types

-- | A keyed MAC function: it takes the secret key first and the message
-- second, and returns the authentication tag.
type HMAC = ByteString -> (ByteString -> ByteString)

-- | Keyed MAC built from two nested applications of the hash @alg@:
--
-- > hash (secret <> pad2 <> hash (secret <> pad1 <> msg))
--
-- Here @pad1@ is @padLen@ bytes of 0x36 and @pad2@ is @padLen@ bytes of
-- 0x5c. @padLen@ is 48 for 'MD5' and 40 for 'SHA1'. This matches the
-- SSLv3 record MAC (RFC 6101); NOTE(review): confirm against callers.
--
-- Only 'MD5' and 'SHA1' are supported. Any other algorithm makes the call
-- fail with an internal 'error'.
macSSL :: Hash -> HMAC
macSSL alg secret msg = digest outerInput
  where
    digest :: ByteString -> ByteString
    digest = hash alg

    -- Outer hash input: secret, outer pad, then the inner digest.
    outerInput :: ByteString
    outerInput = B.concat [secret, B.replicate padLen 0x5c, innerDigest]

    -- Inner hash over secret, inner pad, then the message.
    innerDigest :: ByteString
    innerDigest = digest (B.concat [secret, B.replicate padLen 0x36, msg])

    -- The pad length depends on the hash; other algorithms are rejected.
    padLen :: Int
    padLen = case alg of
        MD5 -> 48
        SHA1 -> 40
        _ -> error ("internal error: macSSL called with " ++ show alg)

-- | HMAC (RFC 2104) over the hash @alg@:
--
-- > H((K' xor opad) <> H((K' xor ipad) <> msg))
--
-- @K'@ is the key right-padded with zero bytes to the hash's block size.
-- A key longer than the block size is first replaced by its digest.
hmac :: Hash -> HMAC
hmac alg secret msg = f $ B.append opad (f $ B.append ipad msg)
  where
    -- Outer and inner keyed pads: K' XORed byte-wise with 0x5c / 0x36.
    opad, ipad :: ByteString
    opad = B.map (xor 0x5c) k'
    ipad = B.map (xor 0x36) k'

    f :: ByteString -> ByteString
    f = hash alg

    -- Block size of the underlying hash. It is compared with byte lengths
    -- below and is already an 'Int', so no numeric conversion is needed.
    bl :: Int
    bl = hashBlockSize alg

    -- Key normalised to the block size.
    k' :: ByteString
    k' = B.append kt pad
      where
        -- RFC 2104 requires an over-long key to be hashed down first.
        kt :: ByteString
        kt = if B.length secret > bl then f secret else secret
        -- Zero padding up to the block size (empty if kt already fills it).
        pad :: ByteString
        pad = B.replicate (bl - B.length kt) 0

-- | Expand @secret@ and @seed@ into @len@ bytes of keyed output, returned as
-- a list of chunks.
--
-- The last chunk is truncated so the lengths sum to @len@. When
-- @len <= 0@ the result is a single empty chunk.
--
-- Starting from @aprev@ (every caller in this module passes the seed
-- here), each step computes @A(i) = f secret A(i-1)@ and emits
-- @f secret (A(i) <> seed)@. This is the @P_hash@ construction of
-- RFC 5246 section 5. Recursion stops once the remaining length fits in
-- one chunk.
hmacIter
    :: HMAC -> ByteString -> ByteString -> ByteString -> Int -> [ByteString]
hmacIter f secret seed aprev len
    | chunkLen >= len = [B.take len chunk]
    | otherwise = chunk : hmacIter f secret seed aNext (len - chunkLen)
  where
    -- Next link of the A(i) chain, and the output block derived from it.
    aNext, chunk :: ByteString
    aNext = f secret aprev
    chunk = f secret (aNext `B.append` seed)

    chunkLen :: Int
    chunkLen = B.length chunk

-- | @P_SHA1@: expand @secret@ with @seed@ into @len@ bytes, using
-- HMAC-SHA1 as the keyed function of @hmacIter@.
prf_SHA1 :: ByteString -> ByteString -> Int -> ByteString
prf_SHA1 secret seed = B.concat . hmacIter (hmac SHA1) secret seed seed

-- | @P_MD5@: expand @secret@ with @seed@ into @len@ bytes, using HMAC-MD5
-- as the keyed function of @hmacIter@.
prf_MD5 :: ByteString -> ByteString -> Int -> ByteString
prf_MD5 secret seed = B.concat . hmacIter (hmac MD5) secret seed seed

-- | The TLS 1.0/1.1 PRF (RFC 2246 section 5).
--
-- It splits @secret@ into two halves that share the middle byte when the
-- length is odd. The first half is expanded with @P_MD5@ and the second
-- with @P_SHA1@, each to @len@ bytes, and the two results are XORed.
prf_MD5SHA1 :: ByteString -> ByteString -> Int -> ByteString
prf_MD5SHA1 secret seed len =
    B.xor (prf_MD5 firstHalf seed len) (prf_SHA1 secondHalf seed len)
  where
    half, oddByte :: Int
    (half, oddByte) = B.length secret `divMod` 2

    -- Each half is ceil(length / 2) bytes; an odd middle byte goes to both.
    firstHalf, secondHalf :: ByteString
    firstHalf = B.take (half + oddByte) secret
    secondHalf = B.drop half secret

-- | @P_SHA256@: expand @secret@ with @seed@ into @len@ bytes, using
-- HMAC-SHA256 as the keyed function of @hmacIter@.
prf_SHA256 :: ByteString -> ByteString -> Int -> ByteString
prf_SHA256 secret seed = B.concat . hmacIter (hmac SHA256) secret seed seed

-- | Generic @P_hash@ expansion with HMAC over @halg@, producing @len@ bytes
-- from @secret@ and @seed@. This is the shape of the TLS 1.2 PRF
-- (RFC 5246 section 5); NOTE(review): callers presumably fold the label
-- into @seed@.
--
-- The 'Version' argument is currently ignored. It is kept so that the PRF
-- can one day depend on the protocol version as well as on the cipher's
-- PRF hash.
prf_TLS :: Version -> Hash -> ByteString -> ByteString -> Int -> ByteString
prf_TLS _version halg secret seed =
    B.concat . hmacIter (hmac halg) secret seed seed