{-# LANGUAGE BangPatterns, MagicHash #-}
{-# LANGUAGE CPP #-}
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#endif
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.CharSet.ByteSet
-- Copyright   :  Edward Kmett 2011
--                Bryan O'Sullivan 2008
-- License     :  BSD3
--
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  non-portable (BangPatterns, MagicHash)
--
-- Fast set membership tests for byte values. The set representation is
-- unboxed for efficiency and uses one bit per byte to represent the presence
-- or absence of a byte in the set.
--
-- This is a fairly minimal API; for most uses you probably want "Data.CharSet" instead.
-----------------------------------------------------------------------------
module Data.CharSet.ByteSet
    (
    -- * Data type
      ByteSet(..)
    -- * Construction
    , fromList
    -- * Lookup
    , member
    ) where

import Data.Bits ((.&.), (.|.))
import Foreign.Storable (peekByteOff, pokeByteOff)
import GHC.Exts ( Int(I#), Word#, iShiftRA#, shiftL#
#if MIN_VERSION_base(4,16,0)
                , Word8#, word8ToWord#, wordToWord8#
#else
                , narrow8Word#
#endif
                )
import GHC.Word (Word8(W8#))
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as I
import qualified Data.ByteString.Unsafe as U

#if MIN_VERSION_base(4,8,0)
import Foreign.Marshal.Utils (fillBytes)
#endif

-- | A set of 'Word8' values, stored as a 32-byte (256-bit) bitmap with one
-- bit per possible byte value.
--
-- Value @w@ is recorded in bit @(w mod 8)@ (counting from the least
-- significant bit) of byte @(w div 8)@ of the wrapped bytestring. Sets built
-- with 'fromList' always carry exactly 32 bytes. The raw constructor is
-- exported, so a hand-built 'ByteSet' should follow the same layout.
--
-- The derived 'Eq', 'Ord' and 'Show' instances operate on the underlying
-- bitmap bytes.
newtype ByteSet = ByteSet B.ByteString deriving (Eq, Ord, Show)

-- | Location of one element's bit inside a 'ByteSet' bitmap: the offset of
-- the byte that holds it, and a mask with exactly that bit set, ready to be
-- combined with the byte via '.&.' or '.|.'.
data I = I {-# UNPACK #-} !Int    -- offset of the byte within the bitmap
           {-# UNPACK #-} !Word8  -- one-hot mask selecting the bit in that byte

-- | Arithmetic (sign-propagating) right shift on 'Int', implemented with the
-- 'iShiftRA#' primop.
--
-- The shift amount must lie in the range 0 to word size - 1. The primop's
-- result is undefined outside that range. 'index' calls this with a
-- shift of 3.
shiftR :: Int -> Int -> Int
shiftR (I# x#) (I# i#) = I# (x# `iShiftRA#` i#)

-- | Left shift on 'Word8', truncating the result to 8 bits.
--
-- The byte's unboxed payload is widened to a machine word, shifted with
-- 'shiftL#' (whose shift amount must be non-negative), and narrowed back to
-- 8 bits, so any bits pushed past position 7 are discarded. The @Compat#@
-- helpers hide the 'Word8' representation difference between base versions.
shiftL :: Word8 -> Int -> Word8
shiftL (W8# x#) (I# i#) =
  W8# (narrow8WordCompat# ((word8ToWordCompat# x#) `shiftL#` i#))

-- | Convert a bit position into the location of that bit in a 'ByteSet'
-- bitmap.
--
-- The byte offset is the position shifted right by 3, i.e. divided by 8.
-- The mask is @1@ shifted left by the low three bits of the position.
-- Callers in this module pass a widened 'Word8', so the position is
-- 0..255 and the offset is 0..31.
index :: Int -> I
index i = I (i `shiftR` 3) (1 `shiftL` (i .&. 7))
{-# INLINE index #-}

-- | Build a 'ByteSet' containing exactly the given bytes.
--
-- A fresh 32-byte bitmap is allocated and zero-filled. Then each element's
-- bit is OR-ed into its byte. Duplicate elements are therefore harmless, and
-- the empty list yields the empty set.
fromList :: [Word8] -> ByteSet
fromList s0 = ByteSet $ I.unsafeCreate 32 $ \t -> do
  -- fillBytes returns () while the legacy I.memset returns a pointer;
  -- binding the result to _ lets either CPP branch typecheck.
  _ <-
#if MIN_VERSION_base(4,8,0)
    fillBytes t 0 32
#else
    I.memset t 0 32
#endif
  let addByte :: Word8 -> IO ()
      addByte c = do
        let I byte bit = index (fromIntegral c)
        prev <- peekByteOff t byte :: IO Word8
        pokeByteOff t byte (prev .|. bit)
  mapM_ addByte s0

-- | Check the set for membership.
--
-- A 'ByteSet' produced by 'fromList' always holds 32 bytes, so every 'Word8'
-- maps to an in-range byte offset. Because the 'ByteSet' constructor is
-- exported, a set can also wrap a shorter bytestring. The length guard
-- reports such missing bytes as absent rather than letting 'U.unsafeIndex'
-- read past the end of the buffer.
member :: Word8 -> ByteSet -> Bool
member w (ByteSet t) = byte < B.length t && U.unsafeIndex t byte .&. bit /= 0
  where
    I byte bit = index (fromIntegral w)

-- Compatibility shims for the unboxed payload of 'Word8'. With base >= 4.16,
-- 'W8#' wraps a 'Word8#'. Before that it wraps a full 'Word#' that is kept
-- narrowed to 8 bits.
#if MIN_VERSION_base(4,16,0)
-- | Widen a 'Word8' payload to a machine word.
word8ToWordCompat# :: Word8# -> Word#
word8ToWordCompat# = word8ToWord#

-- | Truncate a machine word to 8 bits, yielding a 'Word8' payload.
narrow8WordCompat# :: Word# -> Word8#
narrow8WordCompat# = wordToWord8#
#else
-- | The payload is already a 'Word#' here, so widening is the identity.
word8ToWordCompat# :: Word# -> Word#
word8ToWordCompat# x = x

-- | Keep only the low 8 bits, yielding a 'Word8' payload.
narrow8WordCompat# :: Word# -> Word#
narrow8WordCompat# = narrow8Word#
#endif