module Codec.Compression.Zstd.Internal
(
CCtx(..)
, DCtx(..)
, compressWith
, decompressWith
, decompressedSize
, withCCtx
, withDCtx
, withDict
, trainFromSamples
, getDictID
) where
import Codec.Compression.Zstd.Types (Decompress(..), Dict(..))
import Control.Exception.Base (bracket)
import Data.ByteString.Internal (ByteString(..))
import Data.Word (Word, Word8)
import Foreign.C.Types (CInt, CSize)
import Foreign.Marshal.Array (withArray)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import System.IO.Unsafe (unsafePerformIO)
import qualified Codec.Compression.Zstd.FFI as C
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
-- | Compress a strict 'ByteString' with the supplied low-level compressor.
--
-- The level must lie between 1 and 'C.maxCLevel' inclusive; any other
-- value aborts via 'bail'. An error code returned by the compressor is
-- turned into a failure by 'handleError', labelled with the given name.
--
-- The destination buffer is sized with 'C.compressBound'. When the
-- compressed payload is shorter than 128 bytes, or occupies at least half
-- of that buffer, the buffer itself is returned; otherwise the payload is
-- copied into a freshly allocated 'ByteString' of exactly the right size
-- (presumably so the oversized allocation is not kept alive).
compressWith
  :: String
  -- ^ Function name used to label error messages.
  -> (Ptr Word8 -> CSize -> Ptr Word8 -> CSize -> CInt -> IO CSize)
  -- ^ Compressor, invoked as @destination capacity source length level@.
  -> Int
  -- ^ Compression level.
  -> ByteString
  -- ^ Payload to compress.
  -> IO ByteString
compressWith name compressor level (PS sfp off len)
  | level >= 1 && level <= C.maxCLevel = withForeignPtr sfp go
  | otherwise = bail name "unsupported compression level"
  where
    -- Allocate the worst-case output buffer and run the compressor on it.
    go base = do
      bound <- C.compressBound (fromIntegral len)
      outFp <- B.mallocByteString (fromIntegral bound)
      withForeignPtr outFp $ \out -> do
        written <- compressor out bound (base `plusPtr` off)
                     (fromIntegral len) (fromIntegral level)
        handleError written name (finish outFp out bound written)

    -- Either hand back the buffer as is or copy out a right-sized result.
    finish outFp out bound written
      | written < 128 || written >= bound `div` 2 = return (PS outFp 0 n)
      | otherwise = B.create n $ \fresh -> B.memcpy fresh out n
      where
        n = fromIntegral written
-- | Report the decompressed size recorded for a compressed payload.
--
-- Yields 'Nothing' when 'C.getDecompressedSize' reports zero or a value
-- that does not fit in an 'Int'; otherwise the size wrapped in 'Just'.
--
-- NOTE(review): the 'unsafePerformIO' assumes 'C.getDecompressedSize'
-- only reads from the buffer and has no other side effects - confirm.
decompressedSize :: ByteString -> Maybe Int
decompressedSize (PS fp off len) =
  unsafePerformIO $
    withForeignPtr fp $ \base ->
      fmap classify
        (C.getDecompressedSize (base `plusPtr` off) (fromIntegral len))
  where
    classify reported
      | reported == 0 = Nothing
      | reported > fromIntegral (maxBound :: Int) = Nothing
      | otherwise = Just (fromIntegral reported)
-- | Decompress a strict 'ByteString' with the supplied low-level
-- decompressor, invoked as @destination capacity source length@.
--
-- The result is:
--
-- * 'Skip' when 'C.getDecompressedSize' reports a size of zero;
--
-- * 'Error' when the reported size does not fit in an 'Int', or when the
--   decompressor's return value satisfies 'C.isError' (the message then
--   comes from 'C.getErrorName');
--
-- * 'Decompress' with the bytes produced otherwise.
decompressWith :: (Ptr Word8 -> CSize -> Ptr Word8 -> CSize -> IO CSize)
               -> ByteString
               -> IO Decompress
decompressWith decompressor (PS sfp off len) =
  withForeignPtr sfp $ \base -> do
    let input = base `plusPtr` off
    expected <- C.getDecompressedSize input (fromIntegral len)
    case () of
      _ | expected == 0 -> return Skip
        | expected > fromIntegral (maxBound :: Int) ->
            return (Error "invalid compressed payload size")
        | otherwise -> do
            -- The destination is allocated at exactly the reported size.
            outFp <- B.mallocByteString (fromIntegral expected)
            produced <- withForeignPtr outFp $ \out ->
              decompressor out (fromIntegral expected) input (fromIntegral len)
            return $
              if C.isError produced
                then Error (C.getErrorName produced)
                else Decompress (PS outFp 0 (fromIntegral produced))
-- | Newtype wrapper around a raw pointer to the FFI layer's 'C.CCtx'
-- (presumably a zstd compression context - confirm against the FFI module).
newtype CCtx = CCtx { getCCtx :: Ptr C.CCtx }
-- | Run an action with a freshly created 'CCtx'.
--
-- The context is obtained from 'C.createCCtx' (checked by 'C.checkAlloc')
-- and released with 'C.freeCCtx' through 'bracket', so it is freed
-- whether the action returns normally or throws.
withCCtx :: (CCtx -> IO a) -> IO a
withCCtx act = bracket acquire release act
  where
    acquire = fmap CCtx (C.checkAlloc "withCCtx" C.createCCtx)
    release (CCtx ptr) = C.freeCCtx ptr
-- | Newtype wrapper around a raw pointer to the FFI layer's 'C.DCtx'
-- (presumably a zstd decompression context - confirm against the FFI module).
newtype DCtx = DCtx { getDCtx :: Ptr C.DCtx }
-- | Run an action with a freshly created 'DCtx'.
--
-- The context is obtained from 'C.createDCtx' (checked by 'C.checkAlloc')
-- and released with 'C.freeDCtx' through 'bracket', so it is freed
-- whether the action returns normally or throws.
withDCtx :: (DCtx -> IO a) -> IO a
withDCtx act = bracket acquire release act
  where
    acquire = fmap DCtx (C.checkAlloc "withDCtx" C.createDCtx)
    release (DCtx ptr) = C.freeDCtx ptr
-- | Expose the bytes of a 'Dict' to an action as a pointer and a length.
--
-- The pointer already accounts for the offset of the underlying
-- 'ByteString' and is only valid while the action runs.
withDict :: Dict -> (Ptr dict -> CSize -> IO a) -> IO a
withDict (Dict (PS fp off len)) act = withForeignPtr fp run
  where
    run base = act (base `plusPtr` off) (fromIntegral len)
-- | Build a compression dictionary from a list of sample payloads.
--
-- The first argument is the capacity, in bytes, of the buffer the
-- dictionary is written into. The samples are concatenated into a single
-- buffer and passed to 'C.trainFromBuffer' together with an array holding
-- each sample's length.
--
-- Returns 'Left' with the message from 'C.getErrorName' when the result
-- satisfies 'C.isError', and 'Right' with the trained 'Dict' otherwise.
-- A dictionary shorter than 128 bytes, or one filling at least half of the
-- capacity, reuses the allocated buffer; anything else is copied into a
-- right-sized 'ByteString'.
trainFromSamples :: Int
                 -> [ByteString]
                 -> Either String Dict
trainFromSamples capacity samples = unsafePerformIO $
  withArray sampleSizes $ \sizes -> do
    dfp <- B.mallocByteString capacity
    -- 'B.concat' may return one of its inputs unchanged (e.g. when there
    -- is only a single non-empty sample), and that input may be a slice,
    -- so the offset must be honoured rather than assumed to be zero.
    let PS sfp soff _ = B.concat samples
    withForeignPtr dfp $ \dict ->
      withForeignPtr sfp $ \sp -> do
        dsz <- C.trainFromBuffer
                 dict (fromIntegral capacity)
                 (sp `plusPtr` soff) (castPtr sizes)
                 (fromIntegral (length samples))
        if C.isError dsz
          then return (Left (C.getErrorName dsz))
          else fmap (Right . Dict) $ do
            let size = fromIntegral dsz
            if size < 128 || size >= capacity `div` 2
              then return (PS dfp 0 size)
              else B.create size $ \p -> B.memcpy p dict size
  where
    -- Marshalled as 'CSize' so the array layout matches what the C side
    -- reads, instead of relying on 'Int' and 'CSize' having equal width.
    sampleSizes :: [CSize]
    sampleSizes = map (fromIntegral . B.length) samples
-- | Look up the identifier of a dictionary via 'C.getDictID'.
--
-- An identifier of zero is reported as 'Nothing'; any other value is
-- returned in 'Just'. The result is forced before being returned.
getDictID :: Dict -> Maybe Word
getDictID dict = unsafePerformIO $ do
  ident <- withDict dict C.getDictID
  return $! case ident of
    0 -> Nothing
    _ -> Just (fromIntegral ident)
-- | Run the given action unless the size-or-error code signals a failure.
--
-- When 'C.isError' holds for the code, the action is not run and 'bail'
-- is invoked with the function name and the text from 'C.getErrorName'.
handleError :: CSize -> String -> IO a -> IO a
handleError sizeOrError func act =
  if C.isError sizeOrError
    then bail func (C.getErrorName sizeOrError)
    else act
-- | Abort with 'error', prefixing the message with the qualified name of
-- the function that failed.
bail :: String -> String -> a
bail func str = error (concat ["Codec.Compression.Zstd.", func, ": ", str])