{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UnliftedFFITypes #-}

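{- | Compress a contiguous sequence of bytes into a single LZ4 block,
and decompress such blocks. These functions do not perform any
framing, so the caller must supply the expected decompressed length
to 'decompress'.
-}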
module Lz4.Block
  ( -- * Compression
    compress
  , compressU
  , compressHighly
  , compressHighlyU

    -- * Decompression
  , decompress
  , decompressU

    -- * Unsafe Compression
  , compressInto

    -- * Computing buffer size
  , requiredBufferSize
  ) where

import Lz4.Internal (c_hs_compress_HC, requiredBufferSize)

import Control.Monad.ST (runST)
import Control.Monad.ST.Run (runByteArrayST)
import Data.Bytes.Types (Bytes (Bytes))
import Data.Primitive (ByteArray (..), MutableByteArray (..))
import GHC.Exts (ByteArray#, MutableByteArray#)
import GHC.IO (unsafeIOToST)
import GHC.ST (ST (ST))

import qualified Control.Exception
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts

{- | Compress bytes using LZ4's HC algorithm. This is slower
than 'compress' but provides better compression. A higher
compression level increases compression but decreases speed.
This function has undefined behavior on byte sequences larger
than 2,113,929,216 bytes. This calls @LZ4_compress_HC@.

The result is the array built by 'compressHighlyU', viewed as a
slice that starts at offset zero and spans the whole array.
-}
compressHighly ::
  -- | Compression level (Use 9 if uncertain)
  Int ->
  -- | Bytes to compress
  Bytes ->
  Bytes
compressHighly !lvl src =
  -- compressHighlyU shrinks its output to the compressed length, so the
  -- array's own size is the length of the slice.
  let packed = compressHighlyU lvl src
   in Bytes packed 0 (PM.sizeofByteArray packed)

-- | Variant of 'compressHighly' with an unsliced result. The returned
-- array has already been shrunk to exactly the compressed length.
compressHighlyU ::
  -- | Compression level (Use 9 if uncertain)
  Int ->
  -- | Bytes to compress
  Bytes ->
  ByteArray
compressHighlyU !lvl (Bytes (ByteArray src#) srcOff srcLen) = runByteArrayST do
  -- Allocate the capacity the compressor requires for srcLen input bytes
  -- (see 'requiredBufferSize'); it is also passed as the C-side capacity.
  let capacity = requiredBufferSize srcLen
  out@(MutableByteArray out#) <- PM.newByteArray capacity
  written <-
    unsafeIOToST (c_hs_compress_HC src# srcOff out# 0 srcLen capacity lvl)
  -- Trim the unused tail in place before freezing so that the frozen
  -- array's size is the compressed length.
  PM.shrinkMutableByteArray out written
  PM.unsafeFreezeByteArray out

{- | Compress bytes using LZ4.
A higher acceleration factor increases speed but decreases
compression. This function has undefined
behavior on byte sequences larger than 2,113,929,216 bytes.
This calls the C shim @hs_compress_fast@ with the given
acceleration factor.

The result is the array built by 'compressU', viewed as a slice
that starts at offset zero and spans the whole array.
-}
-- NOTE(review): this was previously documented as calling
-- LZ4_compress_default, which takes no acceleration argument. The shim
-- presumably wraps LZ4_compress_fast; confirm against the C source.
compress ::
  -- | Acceleration Factor (Use 1 if uncertain)
  Int ->
  -- | Bytes to compress
  Bytes ->
  Bytes
compress !lvl src =
  -- compressU shrinks its output to the compressed length, so the
  -- array's own size is the length of the slice. Delegating avoids
  -- maintaining two copies of the same FFI sequence.
  let packed = compressU lvl src
   in Bytes packed 0 (PM.sizeofByteArray packed)

{- | Compress bytes using LZ4, pasting the compressed bytes into the
mutable byte array at the specified offset.

Precondition: There must be at least
@'requiredBufferSize' (Bytes.length src)@ bytes available starting
from the offset in the destination buffer. This is checked, and
this function will throw 'Lz4BufferTooSmall' if this invariant is
violated. The check requires all of the following:

* the offset is non-negative;
* the supplied count of remaining bytes is at least the required size;
* the required size actually fits in the destination array after the offset.

Returns the offset just past the last compressed byte written.
-}
compressInto ::
  -- | Acceleration Factor (Use 1 if uncertain)
  Int ->
  -- | Bytes to compress
  Bytes ->
  -- | Destination buffer
  MutableByteArray s ->
  -- | Offset into destination buffer
  Int ->
  -- | Bytes remaining in destination buffer
  Int ->
  -- | Next available offset in destination buffer
  ST s Int
compressInto !lvl (Bytes (ByteArray arr) off len) dst@(MutableByteArray dst#) !doff !dlen = do
  let maxSz = requiredBufferSize len
  dstCap <- PM.getSizeofMutableByteArray dst
  -- The C side is told it may use maxSz bytes starting at doff. Validate
  -- that range against the real array as well as the caller's dlen:
  -- trusting dlen alone would let an inconsistent caller corrupt memory
  -- outside dst. The subtraction form avoids overflow once doff >= 0.
  if doff < 0 || dlen < maxSz || maxSz > dstCap - doff
    then unsafeIOToST (Control.Exception.throwIO Lz4BufferTooSmall)
    else do
      actualSz <-
        unsafeIOToST (c_hs_compress_fast arr off dst# doff len maxSz lvl)
      pure (doff + actualSz)

-- | Variant of 'compress' with an unsliced result. The returned array
-- has already been shrunk to exactly the compressed length.
compressU ::
  -- | Acceleration Factor (Use 1 if uncertain)
  Int ->
  -- | Bytes to compress
  Bytes ->
  ByteArray
compressU !lvl (Bytes (ByteArray src#) srcOff srcLen) = runByteArrayST do
  -- Allocate the capacity the compressor requires for srcLen input bytes
  -- (see 'requiredBufferSize'); it is also passed as the C-side capacity.
  let capacity = requiredBufferSize srcLen
  out@(MutableByteArray out#) <- PM.newByteArray capacity
  written <-
    unsafeIOToST (c_hs_compress_fast src# srcOff out# 0 srcLen capacity lvl)
  -- Trim the unused tail in place, then freeze without copying.
  PM.shrinkMutableByteArray out written
  PM.unsafeFreezeByteArray out

{- | Decompress a byte sequence. Fails (returns 'Nothing') if the
actual decompressed result does not match the given expected length.
On success the result is a slice of exactly the expected length,
starting at offset zero.
-}
decompress ::
  -- | Expected length of decompressed bytes
  Int ->
  -- | Compressed bytes
  Bytes ->
  Maybe Bytes
decompress !dstSz !b = fmap asSlice (decompressU dstSz b)
  where
    -- decompressU only succeeds when exactly dstSz bytes were produced.
    asSlice arr = Bytes arr 0 dstSz

{- | Variant of 'decompress' with an unsliced result.

Returns 'Nothing' in either of these cases:

* the expected length is negative;
* the C decompressor reports any length other than the expected one.
  Malformed input is presumably reported this way, as a negative result
  following @LZ4_decompress_safe@.
-}
decompressU ::
  -- | Expected length of decompressed bytes
  Int ->
  -- | Compressed bytes
  Bytes ->
  Maybe ByteArray
decompressU !dstSz (Bytes (ByteArray arr) off len)
  -- A negative expected length can never be matched, and PM.newByteArray
  -- does not validate its size argument. The expected length may come
  -- from untrusted input, so reject it before allocating.
  | dstSz < 0 = Nothing
  | otherwise = runST do
      dst@(MutableByteArray dst#) <- PM.newByteArray dstSz
      actualSz <-
        unsafeIOToST (c_hs_decompress_safe arr off dst# 0 len dstSz)
      -- Only an exact match means the whole buffer was filled.
      if actualSz == dstSz
        then Just <$> PM.unsafeFreezeByteArray dst
        else pure Nothing

-- | Compress @input size@ bytes of the source array, starting at the
-- source offset, into the destination array starting at the destination
-- offset, with @destination capacity@ bytes of room. Returns the number
-- of bytes written; callers use it to shrink the output or to advance
-- their offset.
--
-- The call must be @unsafe@ because it receives unpinned heap arrays
-- ('ByteArray#' / 'MutableByteArray#'). GHC only guarantees that such
-- arrays are not moved by the garbage collector during an @unsafe@
-- foreign call.
foreign import ccall unsafe "hs_compress_fast"
  c_hs_compress_fast ::
    ByteArray# -> -- Source
    Int -> -- Source offset
    MutableByteArray# s -> -- Destination
    Int -> -- Destination offset
    Int -> -- Input size
    Int -> -- Destination capacity
    Int -> -- Acceleration factor
    IO Int -- Result length

-- | Decompress @input size@ bytes of the source array, starting at the
-- source offset, into the destination array starting at the destination
-- offset, with @destination capacity@ bytes of room. Returns the
-- resulting length. 'decompressU' accepts the output only when this
-- equals the expected length.
-- NOTE(review): presumably negative on malformed input, following
-- LZ4_decompress_safe; confirm against the C shim.
--
-- The call must be @unsafe@ because it receives unpinned heap arrays
-- ('ByteArray#' / 'MutableByteArray#'). GHC only guarantees that such
-- arrays are not moved by the garbage collector during an @unsafe@
-- foreign call.
foreign import ccall unsafe "hs_decompress_safe"
  c_hs_decompress_safe ::
    ByteArray# -> -- Source
    Int -> -- Source offset
    MutableByteArray# s -> -- Destination
    Int -> -- Destination offset
    Int -> -- Input size
    Int -> -- Destination capacity
    IO Int -- Result length

-- | Exception thrown by 'compressInto' when the destination region it is
-- given is not large enough for the compressed output (see
-- 'requiredBufferSize').
data Lz4BufferTooSmall = Lz4BufferTooSmall
  deriving stock (Eq, Show)

-- The class's default methods are sufficient: the exception is wrapped
-- directly in SomeException and displayed via 'show'.
instance Control.Exception.Exception Lz4BufferTooSmall

-- foreign import capi "lz4.h value sizeof(LZ4_stream_t)" lz4StreamSz :: Int
--
-- allocateLz4StreamT :: ST s (MutableByteArray s)
-- allocateLz4StreamT = PM.newPinnedByteArray lz4StreamSz