{-# LANGUAGE BinaryLiterals #-}

module Network.QPACK.HeaderBlock.Prefix (
    -- * Prefix
    encodePrefix,
    decodePrefix,
    encodeRequiredInsertCount,
    decodeRequiredInsertCount,
    encodeBase,
    decodeBase,
) where

import Network.ByteOrder
import Network.HPACK.Internal (decodeI, encodeI)
import qualified UnliftIO.Exception as E

import Imports
import Network.QPACK.Error
import Network.QPACK.Table
import Network.QPACK.Types

----------------------------------------------------------------

-- | Encode a Required Insert Count for the wire (RFC 9204, Section 4.5.1.1).
--
-- A count of zero is sent as @0@.  Any other count is reduced modulo
-- @2 * maxEntries@ and then shifted up by one, so that a non-zero count can
-- never collide with the zero encoding.
--
-- The first argument is the maximum number of dynamic-table entries, the
-- second is the Required Insert Count to encode.
--
-- NOTE(review): a non-zero count combined with @maxEntries == 0@ divides by
-- zero in 'mod'; presumably callers never reference the dynamic table in
-- that configuration -- confirm.
--
-- >>> encodeRequiredInsertCount 3 9
-- 4
-- >>> encodeRequiredInsertCount 128 1000
-- 233
encodeRequiredInsertCount :: Int -> InsertionPoint -> Int
encodeRequiredInsertCount _ 0 = 0
encodeRequiredInsertCount maxEntries (InsertionPoint reqInsertCount) =
    (reqInsertCount `mod` (2 * maxEntries)) + 1

-- | Reconstruct the Required Insert Count on the decoder side
--   (RFC 9204, Section 4.5.1.1).
--
-- Arguments, in order: the maximum number of dynamic-table entries, the
-- decoder's current total number of inserts, and the encoded value read
-- from the wire.
--
-- An encoded value of @0@ always decodes to @0@.  Otherwise the value is
-- un-wrapped relative to the decoder's own insert count.  The result is an
-- impure 'IllegalInsertCount' exception (raised when the result is forced)
-- if
--
-- * the encoded value exceeds @2 * maxEntries@,
-- * the un-wrapped value would have to wrap below zero, or
-- * a non-zero encoded value decodes to zero (zero must be encoded as @0@).
--
-- >>> decodeRequiredInsertCount 3 10 4
-- InsertionPoint 9
-- >>> decodeRequiredInsertCount 128 990 233
-- InsertionPoint 1000
decodeRequiredInsertCount :: Int -> InsertionPoint -> Int -> InsertionPoint
decodeRequiredInsertCount _ _ 0 = 0
decodeRequiredInsertCount maxEntries (InsertionPoint totalNumberOfInserts) encodedInsertCount
    | encodedInsertCount > fullRange = illegal
    -- The candidate is too large, but wrapping once fewer would go negative.
    | reqInsertCount > maxValue && reqInsertCount <= fullRange = illegal
    -- The encoder's value must have wrapped one fewer time.
    | reqInsertCount > maxValue = InsertionPoint (reqInsertCount - fullRange)
    -- A value of 0 must be encoded as 0, never as a non-zero wire value.
    | reqInsertCount == 0 = illegal
    | otherwise = InsertionPoint reqInsertCount
  where
    illegal :: InsertionPoint
    illegal = E.impureThrow IllegalInsertCount
    fullRange :: Int
    fullRange = 2 * maxEntries
    maxValue :: Int
    maxValue = totalNumberOfInserts + maxEntries
    -- Largest multiple of fullRange that does not exceed maxValue.
    maxWrapped :: Int
    maxWrapped = (maxValue `div` fullRange) * fullRange
    reqInsertCount :: Int
    reqInsertCount = maxWrapped + encodedInsertCount - 1

----------------------------------------------------------------

-- | Encode the Base relative to the Required Insert Count as a sign flag
--   and a non-negative delta.
--
-- When the base is at or beyond the Required Insert Count the flag is
-- 'False' and the delta is @base - reqInsCnt@; otherwise the flag is 'True'
-- and the delta is @reqInsCnt - base - 1@.
--
-- >>> encodeBase 6 9
-- (False,3)
-- >>> encodeBase 9 6
-- (True,2)
encodeBase :: InsertionPoint -> BasePoint -> (Bool, Int)
encodeBase (InsertionPoint reqInsCnt) (BasePoint base)
    | diff >= 0 = (False, diff) -- base - reqInsCnt
    | otherwise = (True, negate diff - 1) -- reqInsCnt - base - 1
  where
    diff :: Int
    diff = base - reqInsCnt

-- | Recover the Base from the Required Insert Count, the sign flag and the
--   delta; the inverse of 'encodeBase'.
--
-- With the flag 'False' the base is @reqInsCnt + deltaBase@; with the flag
-- 'True' it is @reqInsCnt - deltaBase - 1@.
--
-- NOTE(review): with the flag 'True' and @deltaBase >= reqInsCnt@ the result
-- is negative; this function does not reject that case -- confirm whether
-- callers validate it.
--
-- >>> decodeBase 6 False 3
-- BasePoint 9
-- >>> decodeBase 9 True 2
-- BasePoint 6
decodeBase :: InsertionPoint -> Bool -> Int -> BasePoint
decodeBase (InsertionPoint reqInsCnt) False deltaBase =
    BasePoint (reqInsCnt + deltaBase)
decodeBase (InsertionPoint reqInsCnt) True deltaBase =
    BasePoint (reqInsCnt - deltaBase - 1)

----------------------------------------------------------------

-- | Encoding the prefix part of header block.
--   This should be used after 'encodeTokenHeader'.
--
--   The write buffer is cleared first.  Two prefixed integers are then
--   written: the encoded Required Insert Count (8-bit prefix), followed by
--   the Delta Base (7-bit prefix) whose leading bit carries the sign
--   produced by 'encodeBase'.
encodePrefix :: WriteBuffer -> DynamicTable -> IO ()
encodePrefix wbuf dyntbl = do
    clearWriteBuffer wbuf
    maxEntries <- getMaxNumOfEntries dyntbl
    baseIndex <- getBasePoint dyntbl
    reqInsCnt <- getLargestReference dyntbl
    -- Required Insert Count
    let ric = encodeRequiredInsertCount maxEntries reqInsCnt
    encodeI wbuf set0 8 ric
    -- Sign bit + Delta Base (7+)
    let (s, delta) = encodeBase reqInsCnt baseIndex
        set
            | s = set1
            | otherwise = set0
    encodeI wbuf set 7 delta

-- | Decoding the prefix part of header block.
--
--   Reads the encoded Required Insert Count (8-bit prefix integer) and the
--   sign bit plus Delta Base (7-bit prefix integer) from the read buffer,
--   and returns the decoded Required Insert Count together with the Base.
--
--   The Required Insert Count is forced here, so an 'IllegalInsertCount'
--   raised by 'decodeRequiredInsertCount' surfaces from this action rather
--   than from whichever later code happens to inspect the result.
decodePrefix :: ReadBuffer -> DynamicTable -> IO (InsertionPoint, BasePoint)
decodePrefix rbuf dyntbl = do
    maxEntries <- getMaxNumOfEntries dyntbl
    totalNumberOfInserts <- getInsertionPoint dyntbl
    -- Required Insert Count
    w8 <- read8 rbuf
    ric <- decodeI 8 w8 rbuf
    reqInsCnt <-
        E.evaluate $
            decodeRequiredInsertCount maxEntries totalNumberOfInserts ric
    -- Sign bit + Delta Base (7+)
    w8' <- read8 rbuf
    let s = w8' `testBit` 7
        w8'' = w8' .&. 0b01111111 -- drop the sign bit before decoding
    delta <- decodeI 7 w8'' rbuf
    let baseIndex = decodeBase reqInsCnt s delta
    qpackDebug dyntbl $
        putStrLn $
            "Required" ++ show reqInsCnt ++ ", " ++ show baseIndex
    return (reqInsCnt, baseIndex)