{-# LANGUAGE MagicHash        #-}
{-# LANGUAGE UnliftedFFITypes #-}

module Data.Bit.SIMD
  ( ompPopcount
  , ompCom
  , ompAnd
  , ompIor
  , ompXor
  , ompAndn
  , ompIorn
  , ompNand
  , ompNior
  , ompXnor
  , reverseBitsC
  , bitIndexC
  , nthBitIndexC
  , selectBitsC
  ) where

import Control.Monad.ST
import Control.Monad.ST.Unsafe (unsafeIOToST)
import Data.Primitive.ByteArray
import GHC.Exts

-- Unsafe call: the GC cannot run while it executes, so an unpinned
-- ByteArray# may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_popcount"
  omp_popcount :: ByteArray# -> Int# -> Int#

-- | SIMD optimized popcount (number of set bits) of the first @len@
-- words of the array. The length is in 32 bit words.
--
-- No bounds check is performed here: the caller must ensure the array
-- holds at least @len@ 32 bit words.
ompPopcount :: ByteArray -> Int -> Int
ompPopcount (ByteArray arg#) (I# len#) =
  I# (omp_popcount arg# len#)
{-# INLINE ompPopcount #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_com"
  omp_com :: MutableByteArray# s -> ByteArray# -> Int# -> IO ()

-- | SIMD optimized bitwise complement. The length is in bytes
-- and the result array should have at least that many bytes.
--
-- The complement of the first @len@ bytes of the argument is written
-- into the result array. No bounds check is performed here.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
ompCom :: MutableByteArray s -> ByteArray -> Int -> ST s ()
ompCom (MutableByteArray res#) (ByteArray arg#) (I# len#) =
  unsafeIOToST (omp_com res# arg# len#)
{-# INLINE ompCom #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_and"
  omp_and :: MutableByteArray# s -> ByteArray# -> ByteArray# -> Int# -> IO ()

-- | SIMD optimized bitwise AND. The length is in bytes
-- and the result array should have at least that many bytes.
--
-- The first @len@ bytes of the two arguments are combined and written
-- into the result array. No bounds check is performed here; presumably
-- both inputs must also hold at least @len@ bytes.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
ompAnd :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> ST s ()
ompAnd (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) =
  unsafeIOToST (omp_and res# arg1# arg2# len#)
{-# INLINE ompAnd #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_ior"
  omp_ior :: MutableByteArray# s -> ByteArray# -> ByteArray# -> Int# -> IO ()

-- | SIMD optimized bitwise OR. The length is in bytes
-- and the result array should have at least that many bytes.
--
-- The first @len@ bytes of the two arguments are combined and written
-- into the result array. No bounds check is performed here; presumably
-- both inputs must also hold at least @len@ bytes.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
ompIor :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> ST s ()
ompIor (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) =
  unsafeIOToST (omp_ior res# arg1# arg2# len#)
{-# INLINE ompIor #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_xor"
  omp_xor :: MutableByteArray# s -> ByteArray# -> ByteArray# -> Int# -> IO ()

-- | SIMD optimized bitwise XOR. The length is in bytes
-- and the result array should have at least that many bytes.
--
-- The first @len@ bytes of the two arguments are combined and written
-- into the result array. No bounds check is performed here; presumably
-- both inputs must also hold at least @len@ bytes.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
ompXor :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> ST s ()
ompXor (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) =
  unsafeIOToST (omp_xor res# arg1# arg2# len#)
{-# INLINE ompXor #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_andn"
  omp_andn :: MutableByteArray# s -> ByteArray# -> ByteArray# -> Int# -> IO ()

-- | SIMD optimized bitwise AND with the second argument inverted.
-- The length is in bytes and the result array should have at least
-- that many bytes.
--
-- The first @len@ bytes of the two arguments are combined and written
-- into the result array. No bounds check is performed here; presumably
-- both inputs must also hold at least @len@ bytes.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
ompAndn :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> ST s ()
ompAndn (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) =
  unsafeIOToST (omp_andn res# arg1# arg2# len#)
{-# INLINE ompAndn #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_iorn"
  omp_iorn :: MutableByteArray# s -> ByteArray# -> ByteArray# -> Int# -> IO ()

-- | SIMD optimized bitwise OR with the second argument inverted.
-- The length is in bytes and the result array should have at least
-- that many bytes.
--
-- The first @len@ bytes of the two arguments are combined and written
-- into the result array. No bounds check is performed here; presumably
-- both inputs must also hold at least @len@ bytes.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
ompIorn :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> ST s ()
ompIorn (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) =
  unsafeIOToST (omp_iorn res# arg1# arg2# len#)
{-# INLINE ompIorn #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_nand"
  omp_nand :: MutableByteArray# s -> ByteArray# -> ByteArray# -> Int# -> IO ()

-- | SIMD optimized bitwise NAND. The length is in bytes
-- and the result array should have at least that many bytes.
--
-- The first @len@ bytes of the two arguments are combined and written
-- into the result array. No bounds check is performed here; presumably
-- both inputs must also hold at least @len@ bytes.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
ompNand :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> ST s ()
ompNand (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) =
  unsafeIOToST (omp_nand res# arg1# arg2# len#)
{-# INLINE ompNand #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_nior"
  omp_nior :: MutableByteArray# s -> ByteArray# -> ByteArray# -> Int# -> IO ()

-- | SIMD optimized bitwise NOR. The length is in bytes
-- and the result array should have at least that many bytes.
--
-- The first @len@ bytes of the two arguments are combined and written
-- into the result array. No bounds check is performed here; presumably
-- both inputs must also hold at least @len@ bytes.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
ompNior :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> ST s ()
ompNior (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) =
  unsafeIOToST (omp_nior res# arg1# arg2# len#)
{-# INLINE ompNior #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_xnor"
  omp_xnor :: MutableByteArray# s -> ByteArray# -> ByteArray# -> Int# -> IO ()

-- | SIMD optimized bitwise XNOR. The length is in bytes
-- and the result array should have at least that many bytes.
--
-- The first @len@ bytes of the two arguments are combined and written
-- into the result array. No bounds check is performed here; presumably
-- both inputs must also hold at least @len@ bytes.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
ompXnor :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> ST s ()
ompXnor (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) =
  unsafeIOToST (omp_xnor res# arg1# arg2# len#)
{-# INLINE ompXnor #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_reverse_bits"
  reverse_bits :: MutableByteArray# s -> ByteArray# -> Int# -> IO ()

-- | Calls the C routine @_hs_bitvec_reverse_bits@, reading the argument
-- array and writing into the result array. Judging by the name it
-- reverses the order of the bits.
-- NOTE(review): confirm the exact semantics against the C source.
--
-- The length is in words. No bounds check is performed here.
--
-- The foreign call is lifted into 'ST' with 'unsafeIOToST'; this is
-- sound only as long as the C routine mutates nothing but the result
-- array.
reverseBitsC :: MutableByteArray s -> ByteArray -> Int -> ST s ()
reverseBitsC (MutableByteArray res#) (ByteArray arg#) (I# len#) =
  unsafeIOToST (reverse_bits res# arg# len#)
{-# INLINE reverseBitsC #-}

-- Unsafe call: the GC cannot run while it executes, so an unpinned
-- ByteArray# may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_bit_index"
  bit_index :: ByteArray# -> Int# -> Bool -> Int#

-- | Pure wrapper around the C routine @_hs_bitvec_bit_index@.
--
-- Presumably returns the index of the first bit equal to @bit@ within
-- the first @len@ units of the array, with a sentinel value when there
-- is none.
-- NOTE(review): confirm the length unit and the not-found value against
-- the C source.
--
-- No bounds check is performed here.
bitIndexC :: ByteArray -> Int -> Bool -> Int
bitIndexC (ByteArray arg#) (I# len#) bit =
  I# (bit_index arg# len# bit)
{-# INLINE bitIndexC #-}

-- Unsafe call: the GC cannot run while it executes, so an unpinned
-- ByteArray# may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_nth_bit_index"
  nth_bit_index :: ByteArray# -> Int# -> Bool -> Int# -> Int#

-- | Pure wrapper around the C routine @_hs_bitvec_nth_bit_index@.
--
-- Presumably returns the index of the @n@-th bit equal to @bit@ within
-- the first @len@ units of the array, with a sentinel value when there
-- are fewer than @n@ such bits.
-- NOTE(review): confirm the length unit, whether @n@ is 0- or 1-based,
-- and the not-found value against the C source.
--
-- No bounds check is performed here.
nthBitIndexC :: ByteArray -> Int -> Bool -> Int -> Int
nthBitIndexC (ByteArray arg#) (I# len#) bit (I# n#) =
  I# (nth_bit_index arg# len# bit n#)
{-# INLINE nthBitIndexC #-}

-- Unsafe call: the GC cannot run while it executes, so unpinned
-- byte arrays may be handed to C directly.
foreign import ccall unsafe "_hs_bitvec_select_bits"
  select_bits_c
    :: MutableByteArray# s
    -> ByteArray#
    -> ByteArray#
    -> Int#
    -> Bool
    -> IO Int

-- | 'ST' wrapper around the C routine @_hs_bitvec_select_bits@. It
-- writes into the result array and returns the 'Int' the C routine
-- produces.
--
-- Presumably one input acts as a mask selecting bits of the other,
-- with @exclude@ inverting the selection, and the result is the number
-- of bits written.
-- NOTE(review): confirm argument roles, length unit and return value
-- against the C source.
--
-- No bounds check is performed here. The foreign call is lifted into
-- 'ST' with 'unsafeIOToST'; this is sound only as long as the C routine
-- mutates nothing but the result array.
selectBitsC
  :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> Bool -> ST s Int
selectBitsC (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) exclude =
  unsafeIOToST (select_bits_c res# arg1# arg2# len# exclude)
{-# INLINE selectBitsC #-}