{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE PatternSynonyms     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy         #-}
{-# LANGUAGE UnboxedTuples       #-}
{-# LANGUAGE ViewPatterns        #-}
module Data.Hashable.XXH3 (
    -- * One shot
    xxh3_64bit_withSeed_ptr,
    xxh3_64bit_withSeed_bs,
    xxh3_64bit_withSeed_ba,
    xxh3_64bit_withSeed_w64,
    xxh3_64bit_withSeed_w32,
    -- * Incremental
    XXH3_State,
    xxh3_64bit_createState,
    xxh3_64bit_reset_withSeed,
    xxh3_64bit_digest,
    xxh3_64bit_update_bs,
    xxh3_64bit_update_ba,
    xxh3_64bit_update_w64,
    xxh3_64bit_update_w32,
) where

import Control.Monad.ST.Unsafe  (unsafeIOToST)
import Data.Array.Byte          (ByteArray (..), MutableByteArray (..))
import Data.ByteString.Internal (ByteString (..), accursedUnutterablePerformIO)
import Data.Word                (Word32, Word64, Word8)
import Foreign                  (Ptr)
import GHC.Exts                 (Int (..), MutableByteArray#, newAlignedPinnedByteArray#)
import GHC.ST                   (ST (..))

import Data.Hashable.FFI

#if MIN_VERSION_base(4,15,0)
import GHC.ForeignPtr (unsafeWithForeignPtr)
#else
import Foreign (ForeignPtr, withForeignPtr)
#endif

#if MIN_VERSION_bytestring(0,11,0)
#else
import Foreign (ForeignPtr, plusForeignPtr)
#endif

#if !MIN_VERSION_base(4,15,0)
-- | Fallback for @base < 4.15@, where "GHC.ForeignPtr" does not yet export
-- @unsafeWithForeignPtr@: simply delegate to the ordinary 'withForeignPtr'.
unsafeWithForeignPtr :: ForeignPtr a -> (Ptr a -> IO b) -> IO b
unsafeWithForeignPtr fptr action = withForeignPtr fptr action
#endif

#if MIN_VERSION_bytestring(0,11,0)
#else
-- | Compatibility pattern for @bytestring < 0.11@, which only provides the
-- three-field 'PS' constructor. Matching folds the offset into the foreign
-- pointer; constructing always uses offset @0@.
pattern BS :: ForeignPtr Word8 -> Int -> ByteString
pattern BS payload size <- (matchBS -> (payload, size))
  where
    BS payload size = PS payload 0 size
{-# COMPLETE BS #-}

-- | View a 'PS'-style 'ByteString' as a pointer already advanced by its
-- offset, paired with its length.
matchBS :: ByteString -> (ForeignPtr Word8, Int)
matchBS bs = case bs of
    PS base start size -> (plusForeignPtr base start, size)
#endif

-------------------------------------------------------------------------------
-- OneShot
-------------------------------------------------------------------------------

-- | Hash the memory behind a 'Ptr'.
--
-- Forwards the pointer, the length (converted to @CSize@ with
-- 'fromIntegral') and the salt to the FFI routine
-- @unsafe_xxh3_64bit_withSeed_ptr@.
--
-- NOTE(review): the length is presumably a byte count and must be
-- non-negative; a negative 'Int' would wrap to a huge @CSize@. Confirm
-- against callers.
xxh3_64bit_withSeed_ptr :: Ptr Word8 -> Int -> Word64 -> IO Word64
xxh3_64bit_withSeed_ptr !ptr !len !salt =
    unsafe_xxh3_64bit_withSeed_ptr ptr (fromIntegral len) salt

-- | Hash 'ByteString'.
--
-- Unpacks the string into its foreign pointer and length, then runs the
-- pointer-based FFI hash over that buffer with the given salt. The 'IO'
-- action is run with 'accursedUnutterablePerformIO' so the result can be
-- exposed as a pure value.
xxh3_64bit_withSeed_bs :: ByteString -> Word64 -> Word64
xxh3_64bit_withSeed_bs (BS fptr len) !salt = accursedUnutterablePerformIO $
    unsafeWithForeignPtr fptr $ \ptr ->
    unsafe_xxh3_64bit_withSeed_ptr ptr (fromIntegral len) salt

-- | Hash (part of) 'ByteArray'.
--
-- The arguments are the array, an offset, a length and the salt. Offset and
-- length are converted to @CSize@ with 'fromIntegral' and handed, together
-- with the unwrapped @ByteArray#@, to @unsafe_xxh3_64bit_withSeed_ba@.
--
-- NOTE(review): offset and length are presumably in bytes and are not
-- bounds-checked here; confirm against the C side.
xxh3_64bit_withSeed_ba :: ByteArray -> Int -> Int -> Word64 -> Word64
xxh3_64bit_withSeed_ba (ByteArray ba) !off !len !salt =
    unsafe_xxh3_64bit_withSeed_ba ba (fromIntegral off) (fromIntegral len) salt

-- | Hash 'Word64'.
--
-- The first argument is the value to hash, the second is the salt; both are
-- forced and passed straight to @unsafe_xxh3_64bit_withSeed_u64@.
xxh3_64bit_withSeed_w64 :: Word64 -> Word64 -> Word64
xxh3_64bit_withSeed_w64 !x !salt =
    unsafe_xxh3_64bit_withSeed_u64 x salt

-- | Hash 'Word32'.
--
-- The first argument is the value to hash, the second is the salt; both are
-- forced and passed straight to @unsafe_xxh3_64bit_withSeed_u32@.
xxh3_64bit_withSeed_w32 :: Word32 -> Word64 -> Word64
xxh3_64bit_withSeed_w32 !x !salt =
    unsafe_xxh3_64bit_withSeed_u32 x salt

-------------------------------------------------------------------------------
-- Incremental
-------------------------------------------------------------------------------

-- | Mutable XXH3 state.
--
-- A thin wrapper around the raw @MutableByteArray#@ that is handed to the
-- FFI routines. The type parameter @s@ is the 'ST' state thread the buffer
-- belongs to. Obtain a value with 'xxh3_64bit_createState'.
data XXH3_State s = XXH3 (MutableByteArray# s)

-- | Create 'XXH3_State'.
--
-- Allocates a pinned buffer of @unsafe_xxh3_sizeof_state@ bytes with
-- 64-byte alignment, runs @unsafe_xxh3_initState@ on it, and wraps it.
xxh3_64bit_createState :: forall s. ST s (XXH3_State s)
xxh3_64bit_createState = do
    -- aligned alloc, otherwise we get segfaults.
    -- see XXH3_createState implementation
    MutableByteArray ba <- newAlignedPinnedByteArray unsafe_xxh3_sizeof_state 64
    unsafeIOToST (unsafe_xxh3_initState ba)
    return (XXH3 ba)

-- | Reset 'XXH3_State' with a seed.
--
-- Runs @unsafe_xxh3_64bit_reset_withSeed@ on the underlying buffer inside
-- 'ST' (via 'unsafeIOToST').
xxh3_64bit_reset_withSeed :: XXH3_State s -> Word64 -> ST s ()
xxh3_64bit_reset_withSeed (XXH3 s) seed =
    unsafeIOToST (unsafe_xxh3_64bit_reset_withSeed s seed)

-- | Return a hash value from a 'XXH3_State'.
--
-- Doesn't mutate given state, so you can update, digest and update again.
--
-- Runs @unsafe_xxh3_64bit_digest@ on the underlying buffer inside 'ST'
-- (via 'unsafeIOToST').
xxh3_64bit_digest :: XXH3_State s -> ST s Word64
xxh3_64bit_digest (XXH3 s) =
    unsafeIOToST (unsafe_xxh3_64bit_digest s)

-- | Update 'XXH3_State' with 'ByteString'.
--
-- Unpacks the string into its foreign pointer and length and feeds that
-- buffer to @unsafe_xxh3_64bit_update_ptr@, with the length converted to
-- @CSize@ by 'fromIntegral'.
xxh3_64bit_update_bs :: XXH3_State s -> ByteString -> ST s ()
xxh3_64bit_update_bs (XXH3 s) (BS fptr len) = unsafeIOToST $
    unsafeWithForeignPtr fptr $ \ptr ->
    unsafe_xxh3_64bit_update_ptr s ptr (fromIntegral len)

-- | Update 'XXH3_State' with (part of) 'ByteArray'.
--
-- The arguments are the state, the array, an offset and a length. Offset
-- and length are converted to @CSize@ with 'fromIntegral' and passed to
-- @unsafe_xxh3_64bit_update_ba@ together with both raw buffers.
--
-- NOTE(review): offset and length are presumably in bytes and are not
-- bounds-checked here; confirm against the C side.
xxh3_64bit_update_ba :: XXH3_State s -> ByteArray -> Int -> Int -> ST s ()
xxh3_64bit_update_ba (XXH3 s) (ByteArray ba) !off !len = unsafeIOToST $
    unsafe_xxh3_64bit_update_ba s ba (fromIntegral off) (fromIntegral len)

-- | Update 'XXH3_State' with 'Word64'.
--
-- Passes the value to @unsafe_xxh3_64bit_update_u64@ along with the
-- state's underlying buffer.
xxh3_64bit_update_w64 :: XXH3_State s -> Word64 -> ST s ()
xxh3_64bit_update_w64 (XXH3 s) w64 = unsafeIOToST $
    unsafe_xxh3_64bit_update_u64 s w64

-- | Update 'XXH3_State' with 'Word32'.
--
-- Passes the value to @unsafe_xxh3_64bit_update_u32@ along with the
-- state's underlying buffer.
xxh3_64bit_update_w32 :: XXH3_State s -> Word32 -> ST s ()
xxh3_64bit_update_w32 (XXH3 s) w32 = unsafeIOToST $
    unsafe_xxh3_64bit_update_u32 s w32

-------------------------------------------------------------------------------
-- mini-primitive
-------------------------------------------------------------------------------

-- | Lifted wrapper around the 'newAlignedPinnedByteArray#' primop: unboxes
-- both 'Int' arguments, runs the primop in 'ST', and boxes the resulting
-- array as a 'MutableByteArray'.
newAlignedPinnedByteArray
    :: Int  -- ^ size
    -> Int  -- ^ alignment
    -> ST s (MutableByteArray s)
{-# INLINE newAlignedPinnedByteArray #-}
newAlignedPinnedByteArray (I# n) (I# k) =
    ST (\s -> case newAlignedPinnedByteArray# n k s of
        (# s', arr #) -> (# s', MutableByteArray arr #))