{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ViewPatterns #-}
#ifndef BITVEC_THREADSAFE
module Data.Bit.Internal
#else
module Data.Bit.InternalTS
#endif
( Bit(..)
, U.Vector(BitVec)
, U.MVector(BitMVec)
, indexWord
, readWord
, writeWord
, unsafeFlipBit
, flipBit
, WithInternals(..)
, modifyByteArray
) where
#include "vector.h"
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Bits
import Data.Bit.Utils
import Data.Primitive.ByteArray
import Data.Ratio
import Data.Typeable
import qualified Data.Vector.Generic as V
import qualified Data.Vector.Generic.Mutable as MV
import qualified Data.Vector.Unboxed as U
import GHC.Generics
#ifdef BITVEC_THREADSAFE
import GHC.Exts
#endif
-- | A single bit, stored as a 'Bool'. Its 'Num' instance implements
-- arithmetic modulo 2, and unboxed vectors of 'Bit' are packed one element
-- per bit. Most class instances are obtained from the underlying 'Bool'.
newtype Bit = Bit
  { unBit :: Bool
  }
  deriving (Bits, Bounded, Enum, Eq, FiniteBits, Generic, NFData, Ord, Typeable)
-- | Arithmetic modulo 2: addition and subtraction are both exclusive-or,
-- multiplication is logical and, every value is its own negation, and an
-- integer literal maps to its parity.
instance Num Bit where
  Bit x + Bit y = Bit (x /= y)
  Bit x - Bit y = Bit (x /= y)
  Bit x * Bit y = Bit (x && y)
  negate b = b
  abs b = b
  signum b = b
  fromInteger n = Bit (odd n)
-- | @Bit False@ converts to 0 and @Bit True@ to 1.
instance Real Bit where
  toRational (Bit b) = if b then 1 else 0
-- | Integer division modulo 2. Dividing by @Bit True@ returns the dividend
-- with a zero remainder; dividing by @Bit False@ throws 'DivideByZero'.
-- The div/mod variants coincide with quot/rem because no value is negative.
instance Integral Bit where
  quotRem x (Bit divisor)
    | divisor = (x, Bit False)
    | otherwise = throw DivideByZero
  quot x (Bit divisor)
    | divisor = x
    | otherwise = throw DivideByZero
  rem _ (Bit divisor)
    | divisor = Bit False
    | otherwise = throw DivideByZero
  divMod = quotRem
  div = quot
  mod = rem
  toInteger (Bit b) = if b then 1 else 0
-- | Division modulo 2. Dividing by @Bit True@ is the identity and dividing
-- by @Bit False@ throws 'DivideByZero'. A rational is converted by dividing
-- the 'Bit' images of its numerator and denominator, so a rational with an
-- even denominator throws 'DivideByZero'.
instance Fractional Bit where
  fromRational r =
    let num = fromInteger (numerator r)
        den = fromInteger (denominator r)
    in num `quot` den
  x / Bit divisor
    | divisor = x
    | otherwise = throw DivideByZero
  recip (Bit divisor)
    | divisor = Bit True
    | otherwise = throw DivideByZero
-- | Renders a bit as the single character @0@ or @1@, ignoring precedence.
instance Show Bit where
  showsPrec _ (Bit b) = case b of
    False -> showChar '0'
    True -> showChar '1'
-- | Parses a single @0@ or @1@ character, skipping any leading space
-- characters (only @' '@, not other whitespace). Anything else fails.
instance Read Bit where
  readsPrec prec = \case
    ' ' : rest -> readsPrec prec rest
    '0' : rest -> [(Bit False, rest)]
    '1' : rest -> [(Bit True, rest)]
    _ -> []
-- | Unboxed vectors of 'Bit' use the bit-packed representations declared
-- below.
instance U.Unbox Bit
-- | Mutable bit vector. Fields: bit offset into the array, length in bits,
-- and the underlying mutable byte array. Element @k@ lives at bit
-- @off + k@, i.e. bit @modWordSize (off + k)@ of word @divWordSize (off + k)@.
data instance U.MVector s Bit = BitMVec !Int !Int !(MutableByteArray s)
-- | Immutable bit vector with the same layout as the mutable one: bit
-- offset, length in bits, and the underlying byte array.
data instance U.Vector Bit = BitVec !Int !Int !ByteArray
-- | Wrapper around an immutable bit vector whose 'Show' instance (compiled
-- only with primitive >= 0.6.3) also prints the internal offset, length and
-- byte array.
newtype WithInternals = WithInternals (U.Vector Bit)
#if MIN_VERSION_primitive(0,6,3)
-- | Debug rendering: a tuple of the bit offset, the length in bits, the raw
-- byte array, and the vector itself.
instance Show WithInternals where
  show (WithInternals vec) = case vec of
    BitVec offset len bytes -> show (offset, len, bytes, vec)
#endif
-- | Extract bit number @i@ (0 is the least significant) from a word.
-- Shifts with 'unsafeShiftR', so @i@ must lie in @[0, wordSize)@.
readBit :: Int -> Word -> Bit
readBit i w = Bit ((w `unsafeShiftR` i) .&. 1 /= 0)
{-# INLINE readBit #-}
-- | Replicate a bit across a whole word: all zeros for @Bit False@, all
-- ones for @Bit True@.
extendToWord :: Bit -> Word
extendToWord (Bit b) = if b then maxBound else 0
-- | Read a whole word of bits from an immutable bit vector, starting at bit
-- position @i'@ (relative to the vector). Bit @i'@ of the vector becomes the
-- least significant bit of the result, bit @i' + 1@ the next one, and so on.
--
-- An empty vector yields 0, and @i'@ is not bounds-checked.
--
-- Bits that lie past the end of the vector are not cleared. They hold
-- whatever the underlying array contains, or zero when the read would run
-- past the last word the vector occupies.
indexWord :: U.Vector Bit -> Int -> Word
indexWord !(BitVec _ 0 _) _ = 0
indexWord !(BitVec offset len' arr) !i'
  | bitIx == 0 = lowWord
  | wordIx == lastWordIx = lowWord `unsafeShiftR` bitIx
  | otherwise =
      (lowWord `unsafeShiftR` bitIx)
        .|. (indexByteArray arr (wordIx + 1) `unsafeShiftL` (wordSize - bitIx))
  where
    absIx = offset + i'
    bitIx = modWordSize absIx
    wordIx = divWordSize absIx
    -- index of the word holding the final bit of the vector
    lastWordIx = divWordSize (offset + len' - 1)
    lowWord = indexByteArray arr wordIx
{-# INLINE indexWord #-}
-- | Read a whole word of bits from a mutable bit vector, starting at bit
-- position @i'@ (relative to the vector). Bit @i'@ of the vector becomes the
-- least significant bit of the result.
--
-- An empty vector yields 0, and @i'@ is not bounds-checked.
--
-- Bits past the end of the vector are not cleared. They hold whatever the
-- array contains, or zero when the read would run past the last word the
-- vector occupies.
readWord :: PrimMonad m => U.MVector (PrimState m) Bit -> Int -> m Word
readWord !(BitMVec _ 0 _) _ = pure 0
readWord !(BitMVec offset len' arr) !i' = do
    lowWord <- readByteArray arr wordIx
    let shiftedLow = lowWord `unsafeShiftR` bitIx
    if bitIx == 0
      then pure lowWord
      else if wordIx == lastWordIx
        then pure shiftedLow
        else do
          -- the requested word straddles two array words
          highWord <- readByteArray arr (wordIx + 1)
          pure $ shiftedLow .|. (highWord `unsafeShiftL` (wordSize - bitIx))
  where
    absIx = offset + i'
    bitIx = modWordSize absIx
    wordIx = divWordSize absIx
    lastWordIx = divWordSize (offset + len' - 1)
#if __GLASGOW_HASKELL__ >= 800
{-# SPECIALIZE readWord :: U.MVector s Bit -> Int -> ST s Word #-}
#endif
{-# INLINE readWord #-}
-- | Update one word of a mutable byte array in place. The word at index
-- @ix@ (counted in 'Word'-sized elements) becomes @old .&. msk .|. new@:
-- @msk@ selects the old bits to keep and @new@ supplies the bits to set.
--
-- In the default build this is a plain read followed by a write.
--
-- In the @BITVEC_THREADSAFE@ build the AND and the OR are each a single
-- atomic fetch-and-modify primop on that word. The two steps are separate
-- operations, so another thread may observe the word between them.
modifyByteArray
  :: PrimMonad m
  => MutableByteArray (PrimState m)
  -> Int
  -> Word
  -> Word
  -> m ()
#ifndef BITVEC_THREADSAFE
modifyByteArray arr ix msk new = do
  old <- readByteArray arr ix
  -- parses as (old .&. msk) .|. new
  writeByteArray arr ix (old .&. msk .|. new)
{-# INLINE modifyByteArray #-}
#else
modifyByteArray (MutableByteArray mba) (I# ix) (W# msk) (W# new) = do
  primitive $ \state ->
    -- clear the bits not selected by msk, then set the bits of new
    let !(# state', _ #) = fetchAndIntArray# mba ix (word2Int# msk) state in
    let !(# state'', _ #) = fetchOrIntArray# mba ix (word2Int# new) state' in
    (# state'', () #)
-- NOTE(review): NOINLINE only for GHC 8.8.1, presumably a workaround for a
-- compiler issue in that release; the reason is not recorded here.
#if __GLASGOW_HASKELL__ == 808 && __GLASGOW_HASKELL_PATCHLEVEL1__ == 1
{-# NOINLINE modifyByteArray #-}
#else
{-# INLINE modifyByteArray #-}
#endif
#endif
-- | Write a whole word of bits into a mutable bit vector, starting at bit
-- position @i'@ (relative to the vector). The least significant bit of @x@
-- goes to position @i'@, the next bit to @i' + 1@, and so on.
--
-- Bits of @x@ that would land at or past the end of the vector are
-- discarded. Array bits outside the written range are kept. A zero-length
-- vector is left untouched, and @i'@ is not bounds-checked.
--
-- NOTE(review): the masks below assume @loMask n@ selects the @n@ low bits
-- of a word and @hiMask n@ is its complement. Both come from
-- "Data.Bit.Utils", which is not visible here.
writeWord :: PrimMonad m => U.MVector (PrimState m) Bit -> Int -> Word -> m ()
writeWord !(BitMVec _ 0 _) _ _ = pure ()
writeWord !(BitMVec off len' arr) !i' !x
  -- Start is word-aligned: store the whole word if it fits in the vector,
  -- otherwise keep the array bits from lenMod upwards.
  | iMod == 0
  = if len >= i + wordSize
    then writeByteArray arr iDiv x
    else modifyByteArray arr iDiv (hiMask lenMod) (x .&. loMask lenMod)
  -- The start word is also the last word of the vector: one masked write
  -- that keeps the bits below the start and, if any, past the end.
  | iDiv == divWordSize (len - 1)
  = if lenMod == 0
    then modifyByteArray arr iDiv (loMask iMod) (x `unsafeShiftL` iMod)
    else modifyByteArray arr iDiv (loMask iMod .|. hiMask lenMod) ((x `unsafeShiftL` iMod) .&. loMask lenMod)
  -- The value straddles two words and the second is the last word of the
  -- vector: the high part of x is also clipped at the vector end.
  | iDiv + 1 == divWordSize (len - 1)
  = do
    modifyByteArray arr iDiv (loMask iMod) (x `unsafeShiftL` iMod)
    if lenMod == 0
      then modifyByteArray arr (iDiv + 1) (hiMask iMod) (x `unsafeShiftR` (wordSize - iMod))
      else modifyByteArray arr (iDiv + 1) (hiMask iMod .|. hiMask lenMod) (x `unsafeShiftR` (wordSize - iMod) .&. loMask lenMod)
  -- The value straddles two words that both lie fully inside the vector.
  | otherwise
  = do
    modifyByteArray arr iDiv (loMask iMod) (x `unsafeShiftL` iMod)
    modifyByteArray arr (iDiv + 1) (hiMask iMod) (x `unsafeShiftR` (wordSize - iMod))
  where
    -- absolute bit index one past the end of the vector
    len = off + len'
    lenMod = modWordSize len
    -- absolute bit index of the first bit written
    i = off + i'
    iMod = modWordSize i
    iDiv = divWordSize i
#if __GLASGOW_HASKELL__ >= 800
{-# SPECIALIZE writeWord :: U.MVector s Bit -> Int -> Word -> ST s () #-}
#endif
{-# INLINE writeWord #-}
-- | Mutable unboxed vectors of 'Bit'. A @BitMVec off len arr@ holds @len@
-- bits packed into the words of @arr@, starting at bit @off@: element @k@
-- is bit @modWordSize (off + k)@ of word @divWordSize (off + k)@.
instance MV.MVector U.MVector Bit where
  -- Zero every bit of the vector.
  {-# INLINE basicInitialize #-}
  basicInitialize vec = MV.basicSet vec (Bit False)
  -- Allocate enough whole words for n bits; the words are not cleared here.
  {-# INLINE basicUnsafeNew #-}
  basicUnsafeNew n
    | n < 0 = error $ "Data.Bit.basicUnsafeNew: negative length: " ++ show n
    | otherwise = do
        arr <- newByteArray (wordsToBytes $ nWords n)
        pure $ BitMVec 0 n arr
  -- Allocate and fill every word with all-zero or all-one bits.
  {-# INLINE basicUnsafeReplicate #-}
  basicUnsafeReplicate n x
    | n < 0 = error
        $ "Data.Bit.basicUnsafeReplicate: negative length: "
        ++ show n
    | otherwise = do
        arr <- newByteArray (wordsToBytes $ nWords n)
        setByteArray arr 0 (nWords n) (extendToWord x :: Word)
        pure $ BitMVec 0 n arr
  -- Overlap is decided at word granularity. Two vectors that share a word
  -- but no bits are still reported as overlapping.
  {-# INLINE basicOverlaps #-}
  basicOverlaps (BitMVec i' m' arr1) (BitMVec j' n' arr2) =
    sameMutableByteArray arr1 arr2
      && (between i j (j + n) || between j i (i + m))
    where
      i = divWordSize i'
      m = nWords (i' + m') - i
      j = divWordSize j'
      n = nWords (j' + n') - j
      between x y z = x >= y && x < z
  {-# INLINE basicLength #-}
  basicLength (BitMVec _ n _) = n
  {-# INLINE basicUnsafeRead #-}
  basicUnsafeRead (BitMVec off _ arr) !i' = do
    let i = off + i'
    word <- readByteArray arr (divWordSize i)
    pure $ readBit (modWordSize i) word
  -- Read-modify-write of the containing word. The thread-safe build uses a
  -- single atomic OR (to set) or AND (to clear) on that word instead.
  {-# INLINE basicUnsafeWrite #-}
#ifndef BITVEC_THREADSAFE
  basicUnsafeWrite (BitMVec off _ arr) !i' !x = do
    let i = off + i'
        j = divWordSize i
        k = modWordSize i
        kk = 1 `unsafeShiftL` k :: Word
    word <- readByteArray arr j
    writeByteArray arr j (if unBit x then word .|. kk else word .&. complement kk)
#else
  basicUnsafeWrite (BitMVec off _ (MutableByteArray mba)) !i' (Bit b) = do
    let i = off + i'
        !(I# j) = divWordSize i
        !(I# k) = 1 `unsafeShiftL` modWordSize i
    primitive $ \state ->
      let !(# state', _ #) =
            (if b
              then fetchOrIntArray# mba j k state
              else fetchAndIntArray# mba j (notI# k) state
            )
      in (# state', () #)
#endif
  -- No-op: the storage is a raw byte array.
  {-# INLINE basicClear #-}
  basicClear _ = pure ()
  {-# INLINE basicSet #-}
  basicSet (BitMVec _ 0 _) _ = pure ()
  -- Vector starts on a word boundary: fill whole words, then patch the
  -- trailing partial word (if any) with a masked write.
  basicSet (BitMVec off len arr) (extendToWord -> x) | offBits == 0 =
    case modWordSize len of
      0 -> setByteArray arr offWords lWords (x :: Word)
      nMod -> do
        setByteArray arr offWords (lWords - 1) (x :: Word)
        modifyByteArray arr (offWords + lWords - 1) (hiMask nMod) (x .&. loMask nMod)
    where
      offBits = modWordSize off
      offWords = divWordSize off
      lWords = nWords (offBits + len)
  -- Vector starts mid-word: masked writes on the partial first and last
  -- words, whole-word fill in between.
  basicSet (BitMVec off len arr) (extendToWord -> x) =
    case modWordSize (off + len) of
      0 -> do
        modifyByteArray arr offWords (loMask offBits) (x .&. hiMask offBits)
        setByteArray arr (offWords + 1) (lWords - 1) (x :: Word)
      nMod -> if lWords == 1
        then do
          let lohiMask = loMask offBits .|. hiMask nMod
          modifyByteArray arr offWords lohiMask (x .&. complement lohiMask)
        else do
          modifyByteArray arr offWords (loMask offBits) (x .&. hiMask offBits)
          setByteArray arr (offWords + 1) (lWords - 2) (x :: Word)
          modifyByteArray arr (offWords + lWords - 1) (hiMask nMod) (x .&. loMask nMod)
    where
      offBits = modWordSize off
      offWords = divWordSize off
      lWords = nWords (offBits + len)
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy _ (BitMVec _ 0 _) = pure ()
  -- Both vectors start on a word boundary: bulk-copy whole words and patch
  -- the trailing partial word, if any.
  basicUnsafeCopy (BitMVec offDst lenDst dst) (BitMVec offSrc _ src)
    | offDstBits == 0, offSrcBits == 0 = case modWordSize lenDst of
      0 -> copyMutableByteArray dst
        (wordsToBytes offDstWords)
        src
        (wordsToBytes offSrcWords)
        (wordsToBytes lDstWords)
      nMod -> do
        copyMutableByteArray dst
          (wordsToBytes offDstWords)
          src
          (wordsToBytes offSrcWords)
          (wordsToBytes $ lDstWords - 1)
        lastWordSrc <- readByteArray src (offSrcWords + lDstWords - 1)
        modifyByteArray dst (offDstWords + lDstWords - 1) (hiMask nMod) (lastWordSrc .&. loMask nMod)
    where
      offDstBits = modWordSize offDst
      offDstWords = divWordSize offDst
      lDstWords = nWords (offDstBits + lenDst)
      offSrcBits = modWordSize offSrc
      offSrcWords = divWordSize offSrc
  -- Both vectors have the same offset within a word: masked first/last
  -- words, bulk copy of the whole words in between.
  basicUnsafeCopy (BitMVec offDst lenDst dst) (BitMVec offSrc _ src)
    | offDstBits == offSrcBits = case modWordSize (offSrc + lenDst) of
      0 -> do
        firstWordSrc <- readByteArray src offSrcWords
        modifyByteArray dst offDstWords (loMask offSrcBits) (firstWordSrc .&. hiMask offSrcBits)
        copyMutableByteArray dst
          (wordsToBytes $ offDstWords + 1)
          src
          (wordsToBytes $ offSrcWords + 1)
          (wordsToBytes $ lDstWords - 1)
      nMod -> if lDstWords == 1
        then do
          let lohiMask = loMask offSrcBits .|. hiMask nMod
          theOnlyWordSrc <- readByteArray src offSrcWords
          modifyByteArray dst offDstWords lohiMask (theOnlyWordSrc .&. complement lohiMask)
        else do
          firstWordSrc <- readByteArray src offSrcWords
          modifyByteArray dst offDstWords (loMask offSrcBits) (firstWordSrc .&. hiMask offSrcBits)
          copyMutableByteArray dst
            (wordsToBytes $ offDstWords + 1)
            src
            (wordsToBytes $ offSrcWords + 1)
            (wordsToBytes $ lDstWords - 2)
          lastWordSrc <- readByteArray src (offSrcWords + lDstWords - 1)
          modifyByteArray dst (offDstWords + lDstWords - 1) (hiMask nMod) (lastWordSrc .&. loMask nMod)
    where
      offDstBits = modWordSize offDst
      offDstWords = divWordSize offDst
      lDstWords = nWords (offDstBits + lenDst)
      offSrcBits = modWordSize offSrc
      offSrcWords = divWordSize offSrc
  -- General case: copy one word-sized chunk at a time through
  -- readWord/writeWord, which handle the differing alignments.
  basicUnsafeCopy dst@(BitMVec _ len _) src = do_copy 0
    where
      n = alignUp len
      do_copy i
        | i < n = do
            x <- readWord src i
            writeWord dst i x
            do_copy (i + wordSize)
        | otherwise = return ()
  -- When the vectors may overlap, copy the source into a temporary buffer
  -- with the same offset within a word first, then from there into dst.
  {-# INLINE basicUnsafeMove #-}
  basicUnsafeMove !dst !src@(BitMVec srcShift srcLen _)
    | MV.basicOverlaps dst src = do
        srcCopy <- MV.drop (modWordSize srcShift)
          <$> MV.basicUnsafeNew (modWordSize srcShift + srcLen)
        MV.basicUnsafeCopy srcCopy src
        MV.basicUnsafeCopy dst srcCopy
    | otherwise = MV.basicUnsafeCopy dst src
  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeSlice offset n (BitMVec off _ arr) = BitMVec (off + offset) n arr
  -- Always allocate a fresh array and copy the old words into it, even when
  -- the extra bits would fit into words that are already allocated.
  -- Reusing the old array made the grown vector alias its source: writes
  -- through one were visible in the other. Initialising the new tail (as
  -- MV.grow does) could also overwrite bits owned by other slices of the
  -- same array.
  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow (BitMVec off len src) byBits = do
      dst <- newByteArray (wordsToBytes newWords)
      copyMutableByteArray dst 0 src 0 (wordsToBytes oldWords)
      pure $ BitMVec off (len + byBits) dst
    where
      oldWords = nWords (off + len)
      newWords = nWords (off + len + byBits)
#ifndef BITVEC_THREADSAFE
-- | Invert the bit at the given position, without bounds checking. Reads
-- the containing word and writes it back with that bit XOR-ed.
unsafeFlipBit :: PrimMonad m => U.MVector (PrimState m) Bit -> Int -> m ()
unsafeFlipBit (BitMVec offset _ arr) !i' = do
  let absIx = offset + i'
      wordIx = divWordSize absIx
      flipMask = 1 `unsafeShiftL` modWordSize absIx :: Word
  oldWord <- readByteArray arr wordIx
  writeByteArray arr wordIx (oldWord `xor` flipMask)
{-# INLINE unsafeFlipBit #-}
#else
-- | Invert the bit at the given position, without bounds checking, using a
-- single atomic XOR primop on the containing word.
unsafeFlipBit :: PrimMonad m => U.MVector (PrimState m) Bit -> Int -> m ()
unsafeFlipBit (BitMVec offset _ (MutableByteArray mba)) !i' = do
  let absIx = offset + i'
      !(I# wordIx#) = divWordSize absIx
      !(I# mask#) = 1 `unsafeShiftL` modWordSize absIx
  primitive $ \s0 ->
    let !(# s1, _ #) = fetchXorIntArray# mba wordIx# mask# s0
    in (# s1, () #)
{-# INLINE unsafeFlipBit #-}
#endif
-- | Invert the bit at the given position, after checking that the index is
-- within the bounds of the vector.
flipBit :: PrimMonad m => U.MVector (PrimState m) Bit -> Int -> m ()
flipBit vec ix =
  BOUNDS_CHECK(checkIndex) "flipBit" ix (MV.length vec) $ unsafeFlipBit vec ix
{-# INLINE flipBit #-}
-- | Immutable unboxed vectors of 'Bit', sharing the layout of the mutable
-- ones: @BitVec off len arr@ holds @len@ bits starting at bit @off@ of
-- @arr@. Every method is marked INLINE, matching the mutable instance.
instance V.Vector U.Vector Bit where
  -- Freeze and thaw reuse the underlying array without copying.
  {-# INLINE basicUnsafeFreeze #-}
  basicUnsafeFreeze (BitMVec s n v) = BitVec s n <$> unsafeFreezeByteArray v
  {-# INLINE basicUnsafeThaw #-}
  basicUnsafeThaw (BitVec s n v) = BitMVec s n <$> unsafeThawByteArray v
  {-# INLINE basicLength #-}
  basicLength (BitVec _ n _) = n
  {-# INLINE basicUnsafeIndexM #-}
  basicUnsafeIndexM (BitVec off _ arr) !i' = do
    let i = off + i'
    pure $! readBit (modWordSize i) (indexByteArray arr (divWordSize i))
  -- Copy via the mutable implementation on a temporary thawed view of src;
  -- src itself is only read.
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy dst src = do
    src1 <- V.basicUnsafeThaw src
    MV.basicUnsafeCopy dst src1
  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeSlice offset n (BitVec off _ arr) = BitVec (off + offset) n arr