{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE UnliftedFFITypes #-}

{- | Compress a contiguous sequence of bytes into an LZ4 frame
containing a single block.
-}
module Lz4.Frame
  ( -- * Compression
    compressHighlyU
  ) where

import Lz4.Internal (c_hs_compress_HC, requiredBufferSize)

import Control.Monad.ST (runST)
import Control.Monad.ST.Run (runByteArrayST)
import Data.Bytes.Types (Bytes (Bytes))
import Data.Int (Int32)
import Data.Primitive (ByteArray (..), MutableByteArray (..))
import Data.Word (Word8)
import GHC.Exts (ByteArray#, MutableByteArray#)
import GHC.IO (unsafeIOToST)
import GHC.ST (ST (ST))

import qualified Control.Exception
import qualified Data.Primitive as PM
import qualified Data.Primitive.ByteArray.LittleEndian as LE
import qualified GHC.Exts as Exts

{- | Use HC compression to produce a frame with a single block.
All optional fields (checksums, content sizes, and dictionary IDs)
are omitted.

The resulting array is laid out as follows:

* bytes 0-3: magic identifier (@04 22 4D 18@)
* bytes 4-6: frame descriptor; bytes 5 and 6 are selected from the
  input length (see @blockDescriptor@)
* bytes 7-10: the value returned by the compressor, written as a
  little-endian 'Int32'
* bytes 11 onward: the compressor's output
* last 4 bytes: zeroes

Note: Currently, this produces incorrect output when the size of
the input to be compressed is greater than 4MiB, which is the
largest size class the frame descriptor written here can declare.
The only way to correct this function is to make it not compress
large input. This can be done by setting the high bit of the size.
This needs to be tested though since it is an uncommon code path.
-}
compressHighlyU ::
  -- | Compression level (Use 9 if uncertain)
  Int ->
  -- | Bytes to compress
  Bytes ->
  ByteArray
compressHighlyU !lvl (Bytes (ByteArray arr) off len) = runST do
  -- 15 bytes of framing surround the payload: 11 in front of it
  -- (magic, descriptor, block size) and 4 behind it.
  -- Presumably 'requiredBufferSize' is the worst-case compressed
  -- size for @len@ input bytes; confirm against Lz4.Internal.
  let maxSz = requiredBufferSize len + 15
      (bd, hc) = blockDescriptor len
  dst@(MutableByteArray dst#) <- PM.newByteArray maxSz
  -- First 4 bytes: magic identifier
  putByte dst 0 0x04
  putByte dst 1 0x22
  putByte dst 2 0x4D
  putByte dst 3 0x18
  -- Next 3 bytes: frame descriptor
  putByte dst 4 0b0110_0000
  putByte dst 5 bd
  putByte dst 6 hc
  -- The payload is written starting at offset 11, leaving room at
  -- offsets 7-10 for its size, which is only known afterwards.
  -- NOTE(review): the sixth argument looks like the destination
  -- capacity. If so, maxSz overstates what is available past
  -- offset 11 by the framing bytes; confirm against the C wrapper.
  actualSz <- unsafeIOToST (c_hs_compress_HC arr off dst# 11 len maxSz lvl)
  LE.writeUnalignedByteArray dst 7 (fromIntegral actualSz :: Int32)
  -- Four zero bytes directly after the payload close the frame.
  putByte dst (actualSz + 11) 0x00
  putByte dst (actualSz + 12) 0x00
  putByte dst (actualSz + 13) 0x00
  putByte dst (actualSz + 14) 0x00
  -- Drop the unused tail of the worst-case allocation.
  PM.shrinkMutableByteArray dst (actualSz + 15)
  PM.unsafeFreezeByteArray dst

-- | Write one byte at the given byte index of a mutable byte array.
putByte :: MutableByteArray s -> Int -> Word8 -> ST s ()
putByte = PM.writeByteArray

{- | Pick the two length-dependent frame descriptor bytes (stored at
offsets 5 and 6 of the frame) for an input of the given length.
The thresholds are 64KiB, 256KiB and 1MiB; anything larger falls
into the last class. The second byte of each pair is a fixed
constant paired with the first (in the LZ4 frame format this
position holds the header checksum).
-}
blockDescriptor :: Int -> (Word8, Word8)
blockDescriptor len
  | len <= 65_536 = (0b0100_0000, 0x82)
  | len <= 262_144 = (0b0101_0000, 0xFB)
  | len <= 1_048_576 = (0b0110_0000, 0x51)
  | otherwise = (0b0111_0000, 0x73)