{-# LANGUAGE RecordWildCards #-}

module Network.HPACK.Huffman.Encode (
    -- * Huffman encoding
    encodeH,
    encodeHuffman,
) where

import Data.Array.Base (unsafeAt)
import Data.Array.IArray (listArray)
import Data.Array.Unboxed (UArray)
import Data.IORef
import Foreign.Ptr (minusPtr, plusPtr)
import Foreign.Storable (poke)
import Network.ByteOrder hiding (copy)
import UnliftIO.Exception (throwIO)

import Imports
import Network.HPACK.Huffman.Params (idxEos)
import Network.HPACK.Huffman.Table

----------------------------------------------------------------

-- | Code length, in bits, of every symbol index from @0@ to 'idxEos'.
-- Entry @i@ is the number of bits in the @i@-th element of 'huffmanTable'.
huffmanLength :: UArray Int Int
huffmanLength = listArray (0, idxEos) [length codeBits | codeBits <- huffmanTable]

-- | Code value of every symbol index from @0@ to 'idxEos', in the same
-- order as 'huffmanTable''.
--
-- NOTE(review): the encoder shifts each value left by
-- @(position - 'huffmanLength' ! i)@, which presumes the code sits in the
-- low-order bits of the 'Word64' — confirm against 'huffmanTable''.
huffmanCode :: UArray Int Word64
huffmanCode = listArray symbolRange huffmanTable'
  where
    symbolRange = (0, idxEos)

----------------------------------------------------------------

-- | Huffman-encode a 'ByteString' into a 'WriteBuffer'.
--
-- Output starts at the buffer's current write offset, which is advanced
-- past the encoded bytes. 'BufferOverrun' is thrown if the encoded form
-- does not fit before the buffer's limit.
encodeH
    :: WriteBuffer
    -- ^ Destination buffer
    -> ByteString
    -- ^ Source bytes to encode
    -> IO Int
    -- ^ The length of the encoded string, in bytes.
encodeH dst bs = withReadBuffer bs $ enc dst

-- | Starting bit position in the 'Word64' encoding accumulator.
--
-- Pending bits are kept just below this position, and each emitted byte is
-- taken from bit offset 'shiftForWrite'. This position must therefore be
-- exactly @shiftForWrite + 8@ (currently 40). Any other value makes the
-- emitted bytes misaligned with the pending bits.
--
-- Once whole bytes have been flushed, at most 7 bits remain pending. With a
-- maximum Huffman code length of 30 bits, the next code therefore still
-- starts at bit 3 or above.
initialOffset :: Int
initialOffset = shiftForWrite + 8

-- | Right shift that brings the next output byte of the encoding
-- accumulator down to the low-order byte, just before it is truncated to a
-- 'Word8' and written.
shiftForWrite :: Int
shiftForWrite = 32

-- | Huffman-encode everything left in the 'ReadBuffer' into the
-- 'WriteBuffer', starting at the write buffer's current offset.
--
-- Symbols are fetched with 'readInt8' until it yields a negative value
-- (presumably end of input).
--
-- Each symbol's code is ORed into a 'Word64' accumulator just below the
-- current bit position, which starts at 'initialOffset'. Whenever that
-- position drops to 'shiftForWrite' or below, the byte at bit offset
-- 'shiftForWrite' is emitted and the accumulator is shifted left by 8.
--
-- If bits are still pending when input ends, the EOS code ('idxEos') is
-- ORed in below them, so its leading bits pad the remainder of the byte.
-- That final byte is then emitted.
--
-- On success the write buffer's offset is advanced past the output, and the
-- number of bytes written is returned. 'BufferOverrun' is thrown when a byte
-- would be written at or beyond the buffer's limit. In that case the stored
-- offset is not updated.
enc :: WriteBuffer -> ReadBuffer -> IO Int
enc WriteBuffer{offset = offRef, limit = lim} rbuf = do
    origin <- readIORef offRef
    finish <- loop origin 0 initialOffset
    writeIORef offRef finish
    return (finish `minusPtr` origin)
  where
    -- ptr: where the next output byte goes
    -- acc: bits not yet written
    -- pos: bit index just above the slot for the next code
    loop ptr acc pos = readInt8 rbuf >>= step
      where
        step sym
            | sym >= 0 = do
                let (acc', pos') = append acc pos sym
                (ptr', acc'', pos'') <- drain ptr acc' pos'
                loop ptr' acc'' pos''
            -- Input exhausted with nothing pending: done.
            | pos == initialOffset = return ptr
            -- Input exhausted mid-byte: pad with the leading EOS bits.
            | otherwise = putByte ptr (fst (append acc pos idxEos))

    -- OR the code for @sym@ into @acc@ directly below bit @pos@ and
    -- return the updated accumulator and the new (lower) position.
    {-# INLINE append #-}
    append acc pos sym = (acc .|. (code `shiftL` pos'), pos')
      where
        pos' = pos - (huffmanLength `unsafeAt` sym)
        code = huffmanCode `unsafeAt` sym

    -- Write the byte found at bit offset 'shiftForWrite' of @w@ to @p@,
    -- refusing to write at or past the buffer limit.
    {-# INLINE putByte #-}
    putByte p w = do
        when (p >= lim) $ throwIO BufferOverrun
        poke p (fromIntegral (w `shiftR` shiftForWrite) :: Word8)
        return (p `plusPtr` 1)

    -- Emit bytes while the bit position is at or below 'shiftForWrite'.
    {-# INLINE drain #-}
    drain p w o
        | o > shiftForWrite = return (p, w, o)
        | otherwise = do
            p' <- putByte p w
            drain p' (w `shiftL` 8) (o + 8)

-- | Huffman-encode a 'ByteString', returning the encoded bytes.
--
-- Encoding goes through a temporary 4096-byte work buffer.
--
-- NOTE(review): inputs whose encoded form exceeds that size presumably fail
-- with 'BufferOverrun'. This depends on 'withWriteBuffer' setting the
-- buffer limit to the requested size; confirm.
encodeHuffman :: ByteString -> IO ByteString
encodeHuffman bs = withWriteBuffer workSize fill
  where
    workSize = 4096
    fill wbuf = () <$ encodeH wbuf bs