{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Crypto.Nite (
    niteEncrypt,
    niteEncrypt',
    niteDecrypt,
    niteDecrypt',
    protectionMask,
    aes128gcmEncrypt,
    makeNonce,
    makeNiteEncrypt,
    makeNiteDecrypt,
    makeNiteProtector,
) where

import Crypto.Cipher.AES
import Crypto.Cipher.Types hiding (Cipher, IV)
import Crypto.Error (maybeCryptoError)
import qualified Data.ByteArray as Byte (convert)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import Foreign.ForeignPtr (newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Alloc (mallocBytes)
import Foreign.Marshal.Utils (copyBytes)
import Foreign.Ptr (Ptr, nullPtr, plusPtr)
import Foreign.Storable (peek, poke)
import Network.TLS hiding (Version)
import Network.TLS.Extra.Cipher

import Network.QUIC.Crypto.Types
import Network.QUIC.Imports
import Network.QUIC.Types

----------------------------------------------------------------

-- It would be nice to take [PlainText] and update AEAD context with
-- [PlainText]. But since each PlainText is not aligned to cipher block,
-- it's impossible.

-- | Select the AEAD encryption function for a TLS 1.3 cipher suite.
--
-- Only the two AES-GCM suites are implemented. Forcing the result for
-- 'cipher_TLS13_AES128CCM_SHA256' or for any other suite calls 'error'.
cipherEncrypt
    :: Cipher
    -> Key
    -> Nonce
    -> PlainText
    -> AssDat
    -> Maybe (CipherText, CipherText)
cipherEncrypt cipher
    | is cipher_TLS13_AES128GCM_SHA256 = aes128gcmEncrypt
    | is cipher_TLS13_AES128CCM_SHA256 = error "cipher_TLS13_AES128CCM_SHA256"
    | is cipher_TLS13_AES256GCM_SHA384 = aes256gcmEncrypt
    | otherwise = error "cipherEncrypt"
  where
    -- True when the requested suite is the given one.
    is = (cipher ==)

-- | Select the AEAD decryption function for a TLS 1.3 cipher suite.
--
-- Mirrors 'cipherEncrypt': only the two AES-GCM suites are implemented and
-- every other suite ends in a call to 'error'.
cipherDecrypt
    :: Cipher
    -> Key
    -> Nonce
    -> CipherText
    -> AssDat
    -> Maybe PlainText
cipherDecrypt cipher
    | is cipher_TLS13_AES128GCM_SHA256 = aes128gcmDecrypt
    | is cipher_TLS13_AES128CCM_SHA256 = error "cipher_TLS13_AES128CCM_SHA256"
    | is cipher_TLS13_AES256GCM_SHA384 = aes256gcmDecrypt
    | otherwise = error "cipherDecrypt"
  where
    -- True when the requested suite is the given one.
    is = (cipher ==)

-- IMPORTANT: the functions below take the key first and return a lambda,
-- so that the key-dependent setup can be memoized across calls.
-- | AES-128-GCM encryption.
--
-- The cipher context is initialised from the key before the returned lambda
-- is entered, so partially applying this function to a key lets that
-- context be shared by later calls.
--
-- Returns the ciphertext and the 16-byte authentication tag as separate
-- values, or 'Nothing' when cipher or AEAD initialisation fails.
aes128gcmEncrypt
    :: Key -> (Nonce -> PlainText -> AssDat -> Maybe (CipherText, CipherText))
aes128gcmEncrypt (Key key) = case maybeCryptoError $ cipherInit key of
    Nothing -> \_ _ _ -> Nothing
    Just (aes :: AES128) -> \(Nonce nonce) plaintext (AssDat ad) ->
        case maybeCryptoError $ aeadInit AEAD_GCM aes nonce of
            Nothing -> Nothing
            Just aead ->
                let (AuthTag tag0, ciphertext) = aeadSimpleEncrypt aead ad plaintext 16
                    tag = Byte.convert tag0
                 in Just (ciphertext, tag)

-- | AES-128-GCM decryption.
--
-- The input is the ciphertext immediately followed by its 16-byte
-- authentication tag. Returns 'Nothing' when cipher or AEAD initialisation
-- fails, when the input is too short to contain a tag, or when
-- 'aeadSimpleDecrypt' rejects the tag.
aes128gcmDecrypt
    :: Key -> (Nonce -> CipherText -> AssDat -> Maybe PlainText)
aes128gcmDecrypt (Key key) = case maybeCryptoError $ cipherInit key of
    Nothing -> \_ _ _ -> Nothing
    Just (aes :: AES128) -> \(Nonce nonce) ciphertag (AssDat ad) ->
        case maybeCryptoError $ aeadInit AEAD_GCM aes nonce of
            Nothing -> Nothing
            Just aead
                -- Without this check a negative index would reach
                -- 'BS.splitAt', yielding an empty ciphertext and a truncated
                -- (possibly empty) tag.
                -- NOTE(review): cryptonite appears to size the tag it
                -- computes from the tag it is given, so a truncated tag
                -- would weaken or bypass verification -- confirm.
                | BS.length ciphertag < 16 -> Nothing
                | otherwise ->
                    let (ciphertext, tag) =
                            BS.splitAt (BS.length ciphertag - 16) ciphertag
                        authtag = AuthTag $ Byte.convert tag
                     in aeadSimpleDecrypt aead ad ciphertext authtag

-- | AES-256-GCM encryption. Same contract as 'aes128gcmEncrypt', with the
-- key interpreted as an 'AES256' key.
aes256gcmEncrypt
    :: Key -> (Nonce -> PlainText -> AssDat -> Maybe (CipherText, CipherText))
aes256gcmEncrypt (Key key) = case maybeCryptoError $ cipherInit key of
    Nothing -> \_ _ _ -> Nothing
    Just (aes :: AES256) -> \(Nonce nonce) plaintext (AssDat ad) ->
        case maybeCryptoError $ aeadInit AEAD_GCM aes nonce of
            Nothing -> Nothing
            Just aead ->
                let (AuthTag tag0, ciphertext) = aeadSimpleEncrypt aead ad plaintext 16
                    tag = Byte.convert tag0
                 in Just (ciphertext, tag)

-- | AES-256-GCM decryption. Same contract as 'aes128gcmDecrypt', with the
-- key interpreted as an 'AES256' key.
aes256gcmDecrypt
    :: Key -> (Nonce -> CipherText -> AssDat -> Maybe PlainText)
aes256gcmDecrypt (Key key) = case maybeCryptoError $ cipherInit key of
    Nothing -> \_ _ _ -> Nothing
    Just (aes :: AES256) -> \(Nonce nonce) ciphertag (AssDat ad) ->
        case maybeCryptoError $ aeadInit AEAD_GCM aes nonce of
            Nothing -> Nothing
            Just aead
                -- Reject input that cannot contain a full 16-byte tag; see
                -- the matching check in the AES-128 variant.
                | BS.length ciphertag < 16 -> Nothing
                | otherwise ->
                    let (ciphertext, tag) =
                            BS.splitAt (BS.length ciphertag - 16) ciphertag
                        authtag = AuthTag $ Byte.convert tag
                     in aeadSimpleDecrypt aead ad ciphertext authtag

----------------------------------------------------------------

-- | Build a nonce by XORing the packet-number bytes into the tail of the IV.
--
-- Calls 'error' (via 'bsXORpad') when the packet-number bytes are longer
-- than the IV.
makeNonce :: IV -> ByteString -> Nonce
makeNonce (IV iv) pn = Nonce nonce
  where
    nonce = bsXORpad iv pn

-- XORing IV and a packet number with left padded.
--        src0
-- IV     +IIIIIIIIIIIIIIIIII--------+
--        diff               src1
-- PN     +000000000000000000+-------+
--        dst
-- Nonce  +IIIIIIIIIIIIIIIIII--------+

-- | XOR the second string into the tail of the first one.
--
-- The result has the length of the first argument: its leading
-- @len0 - len1@ bytes are copied unchanged and the remaining @len1@ bytes
-- are XORed byte by byte with the second argument. Calls 'error' when the
-- second argument is longer than the first.
bsXORpad :: ByteString -> ByteString -> ByteString
bsXORpad (PS fp0 off0 len0) (PS fp1 off1 len1)
    | len0 < len1 = error "bsXORpad"
    | otherwise = BS.unsafeCreate len0 $ \dst ->
        withForeignPtr fp0 $ \p0 ->
            withForeignPtr fp1 $ \p1 -> do
                let src0 = p0 `plusPtr` off0
                let src1 = p1 `plusPtr` off1
                let diff = len0 - len1
                -- The prefix has no counterpart in the shorter input.
                copyBytes dst src0 diff
                loop (dst `plusPtr` diff) (src0 `plusPtr` diff) src1 len1
  where
    -- XOR @len@ bytes from the two sources into the destination.
    loop :: Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> Int -> IO ()
    loop _ _ _ 0 = return ()
    loop dst src0 src1 len = do
        w1 <- peek src0
        w2 <- peek src1
        poke dst (w1 `xor` w2)
        loop (dst `plusPtr` 1) (src0 `plusPtr` 1) (src1 `plusPtr` 1) (len - 1)

{-
bsXORpad' :: ByteString -> ByteString -> ByteString
bsXORpad' iv pn = BS.pack $ zipWith xor ivl pnl
  where
    ivl = BS.unpack iv
    diff = BS.length iv - BS.length pn
    pnl = replicate diff 0 ++ BS.unpack pn
-}

----------------------------------------------------------------

-- | Encrypt into a caller-supplied buffer. The result is the number of
-- bytes written, or @-1@ when encryption fails.
type NiteEncrypt = Buffer -> PlainText -> AssDat -> PacketNumber -> IO Int

-- | Build a buffer-writing encryptor for the given cipher, key and IV.
makeNiteEncrypt :: Cipher -> Key -> IV -> NiteEncrypt
makeNiteEncrypt cipher key iv = niteEncryptWrapper (niteEncrypt cipher key iv)

-- | Adapt a pure encryptor to 'NiteEncrypt': both parts of the result are
-- copied back to back into the destination buffer and the total number of
-- bytes copied is returned. A failed encryption yields @-1@.
niteEncryptWrapper
    :: (PlainText -> AssDat -> PacketNumber -> Maybe (CipherText, CipherText))
    -> NiteEncrypt
niteEncryptWrapper enc dst plaintext ad pn = case enc plaintext ad pn of
    Nothing -> return (-1)
    Just (ciphertext, tag) -> do
        len <- copyBS dst ciphertext
        let dst' = dst `plusPtr` len
        len' <- copyBS dst' tag
        return (len + len')

-- | Pure AEAD encryption keyed by packet number.
--
-- The cipher function and the nonce builder are bound before the returned
-- lambda, so they can be shared once the first three arguments are applied.
-- The nonce is built from the IV and the packet number rendered with
-- 'bytestring64'.
niteEncrypt
    :: Cipher
    -> Key
    -> IV
    -> PlainText
    -> AssDat
    -> PacketNumber
    -> Maybe (CipherText, CipherText)
niteEncrypt cipher key iv =
    let enc = cipherEncrypt cipher key
        mk = makeNonce iv
     in \plaintext header pn ->
            let bytePN = bytestring64 $ fromIntegral pn
                nonce = mk bytePN
             in enc nonce plaintext header

-- | Pure AEAD encryption with an explicitly supplied nonce.
niteEncrypt'
    :: Cipher
    -> Key
    -> Nonce
    -> PlainText
    -> AssDat
    -> Maybe (CipherText, CipherText)
niteEncrypt' cipher key nonce plaintext header =
    cipherEncrypt cipher key nonce plaintext header

----------------------------------------------------------------

-- | Decrypt into a caller-supplied buffer. The result is the number of
-- bytes written, or @-1@ when decryption fails.
type NiteDecrypt = Buffer -> CipherText -> AssDat -> PacketNumber -> IO Int

-- | Build a buffer-writing decryptor for the given cipher, key and IV.
makeNiteDecrypt :: Cipher -> Key -> IV -> NiteDecrypt
makeNiteDecrypt cipher key iv = niteDecryptWrapper (niteDecrypt cipher key iv)

-- | Adapt a pure decryptor to 'NiteDecrypt': the plaintext is copied into
-- the destination buffer and the result of 'copyBS' is returned. A failed
-- decryption yields @-1@.
niteDecryptWrapper
    :: (CipherText -> AssDat -> PacketNumber -> Maybe PlainText)
    -> NiteDecrypt
niteDecryptWrapper dec dst ciphertext ad pn = case dec ciphertext ad pn of
    Nothing -> return (-1)
    Just bs -> copyBS dst bs

-- | Pure AEAD decryption keyed by packet number. Counterpart of
-- 'niteEncrypt'; the nonce is derived in the same way.
niteDecrypt
    :: Cipher
    -> Key
    -> IV
    -> CipherText
    -> AssDat
    -> PacketNumber
    -> Maybe PlainText
niteDecrypt cipher key iv =
    let dec = cipherDecrypt cipher key
        mk = makeNonce iv
     in \ciphertext header pn ->
            let bytePN = bytestring64 (fromIntegral pn)
                nonce = mk bytePN
             in dec nonce ciphertext header

-- | Pure AEAD decryption with an explicitly supplied nonce.
niteDecrypt'
    :: Cipher
    -> Key
    -> Nonce
    -> CipherText
    -> AssDat
    -> Maybe PlainText
niteDecrypt' cipher key nonce ciphertext header =
    cipherDecrypt cipher key nonce ciphertext header

----------------------------------------------------------------

-- | Build the header-protection mask function for a cipher and key.
--
-- The cipher-specific function is bound before the returned lambda so it
-- can be shared across samples.
protectionMask :: Cipher -> Key -> (Sample -> Mask)
protectionMask cipher key =
    let f = cipherHeaderProtection cipher key
     in \sample -> f sample

-- | Select the header-protection primitive for a TLS 1.3 cipher suite.
--
-- Only the two AES-GCM suites are implemented; every other suite ends in a
-- call to 'error'.
cipherHeaderProtection :: Cipher -> Key -> (Sample -> Mask)
cipherHeaderProtection cipher key
    | cipher == cipher_TLS13_AES128GCM_SHA256 = aes128ecbEncrypt key
    | cipher == cipher_TLS13_AES128CCM_SHA256 = error "cipher_TLS13_AES128CCM_SHA256"
    | cipher == cipher_TLS13_AES256GCM_SHA384 = aes256ecbEncrypt key
    | otherwise = error "cipherHeaderProtection"

-- | Compute a mask by AES-128-ECB encrypting the sample.
--
-- NOTE(review): when cipher initialisation fails, a fixed placeholder mask
-- is returned instead of reporting the failure -- confirm this is intended.
aes128ecbEncrypt :: Key -> (Sample -> Mask)
aes128ecbEncrypt (Key key) = case maybeCryptoError $ cipherInit key of
    Nothing -> \_ -> Mask "0123456789012345"
    Just (aes :: AES128) ->
        let encrypt = ecbEncrypt aes
         in \(Sample sample) ->
                let mask = encrypt sample
                 in Mask mask

-- | Compute a mask by AES-256-ECB encrypting the sample.
--
-- NOTE(review): when cipher initialisation fails, a fixed placeholder mask
-- is returned instead of reporting the failure -- confirm this is intended.
aes256ecbEncrypt :: Key -> (Sample -> Mask)
aes256ecbEncrypt (Key key) = case maybeCryptoError $ cipherInit key of
    Nothing -> \_ -> Mask "0123456789012345"
    Just (aes :: AES256) ->
        let encrypt = ecbEncrypt aes
         in \(Sample sample) ->
                let mask = encrypt sample
                 in Mask mask

----------------------------------------------------------------

-- | Create a pair of actions for buffer-based header protection.
--
-- The first action records the buffer holding the sample; the second
-- computes the mask for the recorded sample and returns the buffer that
-- contains it. The 32-byte output buffer is allocated with 'mallocBytes'
-- and is never freed here (see the fixme).
makeNiteProtector :: Cipher -> Key -> IO (Buffer -> IO (), IO Buffer)
makeNiteProtector cipher key = do
    ref <- newIORef nullPtr
    dstbuf <- mallocBytes 32 -- fixme: free
    return (niteSetSample ref, niteGetMask ref samplelen mkMask dstbuf)
  where
    samplelen = 16 -- sampleLength cipher -- fixme
    mkMask = protectionMask cipher key

-- | Record the buffer from which the next sample will be read.
niteSetSample :: IORef Buffer -> Buffer -> IO ()
niteSetSample = writeIORef

-- | Compute the mask for the currently recorded sample.
--
-- The recorded buffer is viewed in place as a @samplelen@-byte string,
-- the mask is computed from it and copied into @dstbuf@, which is returned.
--
-- NOTE(review): the reference starts out as 'nullPtr' in
-- 'makeNiteProtector' and is not checked here, so calling this before a
-- sample has been recorded would read through a null pointer.
niteGetMask :: IORef Buffer -> Int -> (Sample -> Mask) -> Buffer -> IO Buffer
niteGetMask ref samplelen mkMask dstbuf = do
    srcbuf <- readIORef ref
    sample <- do
        -- No finalizer is attached: the buffer is owned by the caller.
        fptr <- newForeignPtr_ srcbuf
        return $ PS fptr 0 samplelen
    let Mask mask = mkMask $ Sample sample
    _len <- copyBS dstbuf mask
    return dstbuf