{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.QUIC.Packet.Decrypt (
    decryptCrypt,
) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import Foreign.Ptr
import Network.ByteOrder

import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Imports
import Network.QUIC.Packet.Frame
import Network.QUIC.Packet.Header
import Network.QUIC.Packet.Number
import Network.QUIC.Types

----------------------------------------------------------------

-- | Remove header protection from a received packet and decrypt its payload.
--
-- The steps are:
--
--   1. Take the header-protection sample, which starts 4 bytes after the
--      packet-number offset, and derive the mask with this level's
--      'Protector'.
--   2. Unmask the first (flags) byte with the first mask byte. Read the
--      packet-number length from the unmasked flags, then unmask the
--      packet-number bytes with the remaining mask bytes.
--   3. Rebuild the unprotected header, which is the AEAD associated data,
--      and decrypt the payload into 'decryptBuf'.
--   4. Decode the frames and record reserved-bit, unknown-frame and
--      no-frame conditions in the 'Plain' marks.
--
-- Returns 'Nothing' in three cases:
--
--   * the packet is too short to hold a complete sample;
--   * the derived mask is empty;
--   * 'decrypt' reports failure with a negative size.
decryptCrypt :: Connection -> Crypt -> EncryptionLevel -> IO (Maybe Plain)
decryptCrypt conn Crypt{..} lvl = do
    cipher <- getCipher conn lvl
    protector <- getProtector conn lvl
    let sampleOffset = cryptPktNumOffset + 4
        sampleLen = sampleLength cipher
        sampleBytes = BS.take sampleLen $ BS.drop sampleOffset cryptPacket
    -- RFC 9001 Section 5.4.2: packets that are not long enough to contain a
    -- complete sample MUST be discarded.
    --
    -- Since the header length (offset + packet-number length) never exceeds
    -- 'sampleOffset', this check also ensures two more things:
    --   * the rebuilt header below is fully initialised;
    --   * byte 0 exists for 'BS.index'.
    if BS.length sampleBytes < sampleLen
        then return Nothing
        else do
            let Mask mask = unprotect protector (Sample sampleBytes)
            case BS.uncons mask of
                Nothing -> return Nothing
                Just (mask1, mask2) -> unprotectAndDecrypt mask1 mask2
  where
    -- Given the first mask byte (for the flags) and the remaining mask
    -- bytes (for the packet number), remove header protection, then decrypt
    -- and decode the payload.
    unprotectAndDecrypt :: Word8 -> ByteString -> IO (Maybe Plain)
    unprotectAndDecrypt mask1 mask2 = do
        let proFlags = Flags (cryptPacket `BS.index` 0)
            rawFlags@(Flags flags) = unprotectFlags proFlags mask1
            epnLen = decodePktNumLength rawFlags
            epn = BS.take epnLen $ BS.drop cryptPktNumOffset cryptPacket
            bytePN = bsXOR mask2 epn
            headerLen = cryptPktNumOffset + epnLen
            (proHeader, ciphertext) = BS.splitAt headerLen cryptPacket
        -- Only 1-RTT packets are decoded relative to the peer's packet
        -- number; every other level uses 0 as the base.
        peerPN <-
            if lvl == RTT1Level
                then getPeerPacketNumber conn
                else return 0
        let pn = decodePacketNumber peerPN (toEncodedPacketNumber bytePN) epnLen
        -- The associated data is the header with protection removed.
        -- Start from the protected bytes, then overwrite the flags byte
        -- and the packet-number bytes with their unmasked values.
        header <- BS.create headerLen $ \p -> do
            void $ copy p proHeader
            poke8 flags p 0
            void $ copy (p `plusPtr` cryptPktNumOffset) $ BS.take epnLen bytePN
        -- Only 1-RTT packets carry a key-phase bit (bit 2 of the flags).
        let keyPhase
                | lvl == RTT1Level = flags `testBit` 2
                | otherwise = False
        coder <- getCoder conn lvl keyPhase
        siz <- decrypt coder (decryptBuf conn) ciphertext (AssDat header) pn
        -- The reserved bits sit at different positions in 1-RTT and
        -- long-header flags.
        let rrMask :: Word8
            rrMask
                | lvl == RTT1Level = 0x18
                | otherwise = 0x0c
            marks
                | flags .&. rrMask == 0 = defaultPlainMarks
                | otherwise = setIllegalReservedBits defaultPlainMarks
        if siz < 0
            then return Nothing
            else do
                mframes <- decodeFramesBuffer (decryptBuf conn) siz
                return $ Just $ case mframes of
                    Nothing -> Plain rawFlags pn [] (setUnknownFrame marks)
                    Just frames
                        | null frames -> Plain rawFlags pn frames (setNoFrames marks)
                        | otherwise -> Plain rawFlags pn frames marks

-- | Read a byte string as a big-endian unsigned number, so the first byte
-- is the most significant. 'decryptCrypt' uses it on the unmasked
-- packet-number bytes before full packet-number reconstruction.
toEncodedPacketNumber :: ByteString -> EncodedPacketNumber
toEncodedPacketNumber = BS.foldl' pushByte 0
  where
    -- Shift the accumulated value left by one byte and add the next byte.
    pushByte acc byte = acc * 256 + fromIntegral byte