{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Crypto.Utils (
    tagLength
  , sampleLength
  , bsXOR
  , calculateIntegrityTag
  ) where

import qualified Data.ByteArray as Byte (xor)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Short as Short
import Network.TLS hiding (Version)
import Network.TLS.Extra.Cipher

import Network.QUIC.Crypto.Nite
import Network.QUIC.Crypto.Types
import Network.QUIC.Imports
import Network.QUIC.Types

----------------------------------------------------------------

-- | Byte-wise XOR of two 'ByteString's.
--
-- A monomorphic wrapper around 'Byte.xor' from "Data.ByteArray": both
-- inputs and the result are pinned to 'ByteString'.
--
-- NOTE(review): when the inputs differ in length, the result length is
-- whatever 'Byte.xor' yields. The memory package documents this as the
-- length of the shorter input; confirm.
bsXOR :: ByteString -> ByteString -> ByteString
bsXOR lhs rhs = lhs `Byte.xor` rhs

----------------------------------------------------------------

-- | Tag length, in bytes, for a TLS 1.3 cipher suite.
--
-- Only the three AES-based suites listed in @known@ are accepted, and each
-- of them maps to 16. Any other cipher aborts with @error "tagLength"@,
-- so callers must only pass ciphers from that list.
tagLength :: Cipher -> Int
tagLength cipher
  | any (cipher ==) known = 16
  | otherwise             = error "tagLength"
  where
    -- Checked in this order, exactly like the original guard chain.
    known =
      [ cipher_TLS13_AES128GCM_SHA256
      , cipher_TLS13_AES128CCM_SHA256
      , cipher_TLS13_AES256GCM_SHA384
      ]

-- | Sample length, in bytes, for a TLS 1.3 cipher suite.
--
-- Presumably this is the size of the ciphertext sample used for QUIC
-- header protection; verify against callers. Each of the three AES-based
-- suites below maps to 16. Any other cipher aborts with
-- @error "sampleLength"@.
sampleLength :: Cipher -> Int
sampleLength cipher =
    if any (cipher ==) accepted
      then 16
      else error "sampleLength"
  where
    accepted =
      [ cipher_TLS13_AES128GCM_SHA256
      , cipher_TLS13_AES128CCM_SHA256
      , cipher_TLS13_AES256GCM_SHA384
      ]

----------------------------------------------------------------

-- | Integrity tag computed over a pseudo-packet with AES-128-GCM.
--
-- The associated data is built from three parts, in order:
--
--   * the one-byte length of @oCID@;
--   * the bytes of @oCID@ itself;
--   * the caller-supplied bytes @pseudo0@.
--
-- The plaintext is empty. The two halves returned by 'aes128gcmEncrypt'
-- are concatenated to form the result. If 'aes128gcmEncrypt' returns
-- 'Nothing', the result is the empty string.
--
-- The key and nonce are fixed per 'Version'. The constants appear to be
-- the Retry Integrity Tag values published in RFC 9001 §5.8 (Version1),
-- draft-29, and RFC 9369 (Version2); verify.
--
-- NOTE(review): any other version receives the placeholder bytes
-- "not supported" as key and nonce. Presumably this makes encryption fail
-- and yields ""; confirm.
calculateIntegrityTag :: Version -> CID -> ByteString -> ByteString
calculateIntegrityTag ver oCID pseudo0 =
    maybe "" (uncurry BS.append) sealed
  where
    sealed = aes128gcmEncrypt retryKey retryNonce "" (AssDat pseudoPacket)

    (odcid, odcidLen) = unpackCID oCID
    pseudoPacket = BS.cons odcidLen (Short.fromShort odcid `BS.append` pseudo0)

    (retryKey, retryNonce) = retrySecrets ver

    retrySecrets :: Version -> (Key, Nonce)
    retrySecrets Draft29 =
        ( Key   "\xcc\xce\x18\x7e\xd0\x9a\x09\xd0\x57\x28\x15\x5a\x6c\xb9\x6b\xe1"
        , Nonce "\xe5\x49\x30\xf9\x7f\x21\x36\xf0\x53\x0a\x8c\x1c"
        )
    retrySecrets Version1 =
        ( Key   "\xbe\x0c\x69\x0b\x9f\x66\x57\x5a\x1d\x76\x6b\x54\xe3\x68\xc8\x4e"
        , Nonce "\x46\x15\x99\xd3\x5d\x63\x2b\xf2\x23\x98\x25\xbb"
        )
    retrySecrets Version2 =
        ( Key   "\x8f\xb4\xb0\x1b\x56\xac\x48\xe2\x60\xfb\xcb\xce\xad\x7c\xcc\x92"
        , Nonce "\xd8\x69\x69\xbc\x2d\x7c\x6d\x99\x90\xef\xb0\x4a"
        )
    retrySecrets _ =
        ( Key   "not supported"
        , Nonce "not supported"
        )