{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
#ifndef BITVEC_THREADSAFE
module Data.Bit.Mutable
#else
module Data.Bit.MutableTS
#endif
( castFromWordsM
, castToWordsM
, cloneToWordsM
, cloneToWords8M
, zipInPlace
, invertInPlace
, selectBitsInPlace
, excludeBitsInPlace
, reverseInPlace
) where
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
#ifndef BITVEC_THREADSAFE
import Data.Bit.Internal
#else
import Data.Bit.InternalTS
#endif
import Data.Bit.Utils
import Data.Bits
import Data.Primitive.ByteArray
import qualified Data.Vector.Primitive as P
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
import Data.Word
-- | Reinterpret a mutable vector of machine words as a mutable vector of bits,
-- sharing the same underlying byte array (no copying). Word offset and length
-- are scaled to bit offset and length.
castFromWordsM :: MVector s Word -> MVector s Bit
castFromWordsM (MU.MV_Word (P.MVector wordOff wordCount buffer)) =
  BitMVec bitOff bitCount buffer
  where
    bitOff   = mulWordSize wordOff
    bitCount = mulWordSize wordCount
-- | Try to view a mutable bit vector as a mutable vector of machine words,
-- sharing the same underlying byte array (no copying).
--
-- Succeeds only when both the bit offset and the bit length pass 'aligned'
-- (i.e. sit on word boundaries); otherwise returns 'Nothing'.
castToWordsM :: MVector s Bit -> Maybe (MVector s Word)
castToWordsM (BitMVec bitOff bitLen buffer)
  | aligned bitOff && aligned bitLen =
      Just (MU.MV_Word wordView)
  | otherwise =
      Nothing
  where
    wordView = P.MVector (divWordSize bitOff) (divWordSize bitLen) buffer
-- | Copy a mutable bit vector into a freshly allocated mutable vector of
-- machine words.
--
-- The buffer is rounded up to a whole number of words. Padding bits after
-- the copied data are set to zero, so the last word is fully defined.
cloneToWordsM
  :: PrimMonad m
  => MVector (PrimState m) Bit
  -> m (MVector (PrimState m) Word)
cloneToWordsM v = do
  buf@(BitMVec _ _ arr) <- MU.unsafeNew paddedBits
  MU.unsafeCopy (MU.slice 0 nBits buf) v
  -- Zero the padding between the data and the end of the last word.
  MU.set (MU.slice nBits (paddedBits - nBits) buf) (Bit False)
  pure (MU.MV_Word (P.MVector 0 wordCount arr))
  where
    nBits      = MU.length v
    wordCount  = nWords nBits
    paddedBits = mulWordSize wordCount
{-# INLINE cloneToWordsM #-}
-- | Copy a mutable bit vector into a freshly allocated mutable vector of
-- bytes.
--
-- The result has @ceiling (length v / 8)@ bytes. Padding bits after the copied
-- data, up to the next byte boundary, are set to zero.
--
-- NOTE(review): the byte view is taken directly over the word-packed buffer,
-- so bit @i@ lands in byte @i `div` 8@ only on little-endian targets — confirm
-- whether big-endian platforms need a byte swap here.
cloneToWords8M
  :: PrimMonad m
  => MVector (PrimState m) Bit
  -> m (MVector (PrimState m) Word8)
cloneToWords8M v = do
  buf@(BitMVec _ _ arr) <- MU.unsafeNew paddedBits
  MU.unsafeCopy (MU.slice 0 nBits buf) v
  -- Zero the padding between the data and the end of the last byte.
  MU.set (MU.slice nBits (paddedBits - nBits) buf) (Bit False)
  pure (MU.MV_Word8 (P.MVector 0 nBytes arr))
  where
    nBits      = MU.length v
    nBytes     = (nBits + 7) `shiftR` 3
    paddedBits = nBytes `shiftL` 3
{-# INLINE cloneToWords8M #-}
-- | Combine a read-only bit vector into a mutable one, storing @f x y@ back
-- into the mutable vector, where @x@ comes from the first argument and @y@
-- from the second. Operates on the first @min (length xs) (length ys)@ bits.
--
-- The combining function is rank-2 polymorphic over 'Bits'. This lets the
-- implementation apply it to whole 'Word's at a time instead of bit by bit.
-- NOTE(review): this presumes @f@ acts position-wise (like '.&.', '.|.',
-- 'xor'). A function that carries information across bit positions (e.g. '+')
-- would give alignment-dependent results — confirm the intended contract.
zipInPlace
  :: forall m.
     PrimMonad m
  => (forall a . Bits a => a -> a -> a)
  -> Vector Bit
  -> MVector (PrimState m) Bit
  -> m ()
zipInPlace f (BitVec off l xs) (BitMVec off' l' ys) =
  go (l `min` l') off off'
  where
    -- Entry point: @len@ bits remain, starting at bit offset @offXs@ in @xs@
    -- and @offYs@ in @ys@. If @ys@ does not start on a word boundary, its
    -- leading partial word is handled here, so that go' can assume @ys@ is
    -- word-aligned.
    go :: Int -> Int -> Int -> m ()
    go len offXs offYs
      -- ys already starts on a word boundary: pass its word index to go'.
      | shft == 0 =
        go' len offXs (divWordSize offYs)
      -- All remaining bits fit within a single word of ys.
      | len <= wordSize = do
        y <- readWord vecYs 0
        writeWord vecYs 0 (f x y)
      -- Fill the upper (wordSize - shft) bits of ys's first word, then
      -- continue from the next ys word, which is aligned.
      | otherwise = do
        y <- readByteArray ys base
        -- NOTE(review): presumably modifyByteArray keeps the old bits selected
        -- by @loMask shft@ and merges in the new value. @hiMask shft@ keeps
        -- the new value out of those low bits — confirm in Data.Bit.Utils.
        modifyByteArray ys base (loMask shft) (f (x `unsafeShiftL` shft) y .&. hiMask shft)
        go' (len - wordSize + shft) (offXs + wordSize - shft) (base + 1)
      where
        vecXs = BitVec  offXs len xs
        vecYs = BitMVec offYs len ys
        x     = indexWord vecXs 0
        shft  = modWordSize offYs   -- bit position of offYs within its word
        base  = divWordSize offYs   -- index of the ys word containing offYs

    -- From here on ys is word-aligned: @offYsW@ is a word index into @ys@.
    -- Whole words are combined first. A trailing partial word, if any, is
    -- combined via readWord/writeWord.
    go' :: Int -> Int -> Int -> m ()
    go' len offXs offYsW = do
      if shft == 0
        then loopAligned offYsW
        else loop offYsW (indexByteArray xs base)
      -- Leftover bits that do not fill a whole word.
      when (modWordSize len /= 0) $ do
        let ix = len - modWordSize len
        let x = indexWord vecXs ix
        y <- readWord vecYs ix
        writeWord vecYs ix (f x y)

      where
        vecXs = BitVec  offXs len xs
        vecYs = BitMVec (mulWordSize offYsW) len ys
        shft  = modWordSize offXs   -- bit position of offXs within its word
        shft' = wordSize - shft
        base  = divWordSize offXs   -- index of the xs word containing offXs
        -- Offset that maps ys word index i to the matching xs word index.
        base0 = base - offYsW
        base1 = base0 + 1
        iMax  = divWordSize len + offYsW   -- one past the last whole ys word

        -- xs and ys both aligned: combine stored words one-to-one.
        loopAligned :: Int -> m ()
        loopAligned !i
          | i >= iMax = pure ()
          | otherwise = do
            let x = indexByteArray xs (base0 + i) :: Word
            y <- readByteArray ys i
            writeByteArray ys i (f x y)
            loopAligned (i + 1)

        -- xs unaligned: each logical xs word straddles two stored words.
        -- @acc@ carries the previously read stored word into the next
        -- iteration, so each stored word is fetched once.
        loop :: Int -> Word -> m ()
        loop !i !acc
          | i >= iMax = pure ()
          | otherwise = do
            let accNew = indexByteArray xs (base1 + i)
                x = (acc `unsafeShiftR` shft) .|. (accNew `unsafeShiftL` shft')
            y <- readByteArray ys i
            writeByteArray ys i (f x y)
            loop (i + 1) accNew
#if __GLASGOW_HASKELL__ >= 800
{-# SPECIALIZE zipInPlace :: (forall a. Bits a => a -> a -> a) -> Vector Bit -> MVector s Bit -> ST s () #-}
#endif
{-# INLINE zipInPlace #-}
-- | Flip every bit of a mutable bit vector in place, one word-sized chunk at
-- a time (starting at bit offsets 0, wordSize, 2 * wordSize, ...).
invertInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> m ()
invertInPlace xs = step 0
  where
    total = MU.length xs

    step !pos
      | pos >= total = pure ()
      | otherwise = do
        readWord xs pos >>= writeWord xs pos . complement
        step (pos + wordSize)
#if __GLASGOW_HASKELL__ >= 800
{-# SPECIALIZE invertInPlace :: U.MVector s Bit -> ST s () #-}
#endif
-- | Keep only the bits of @xs@ whose corresponding bit in @is@ is set, and
-- pack them to the front of @xs@. Returns how many bits were kept.
--
-- Only the first @min (length is) (length xs)@ positions are considered.
-- The kept bits occupy the prefix of @xs@ of the returned length. Contents
-- after that prefix are whatever the word-sized writes left there and should
-- not be relied on.
selectBitsInPlace
  :: PrimMonad m => U.Vector Bit -> U.MVector (PrimState m) Bit -> m Int
selectBitsInPlace is xs = go 0 0
  where
    !limit = min (U.length is) (MU.length xs)

    -- @pos@: next source bit to read; @kept@: bits written so far.
    -- Writing at @kept <= pos@ never clobbers bits not yet read.
    go !pos !kept
      | pos >= limit = pure kept
      | otherwise = do
        word <- readWord xs pos
        -- presumably 'masked' drops mask bits past the end of the range
        case selectWord (masked (limit - pos) (indexWord is pos)) word of
          (picked, packed) -> do
            writeWord xs kept packed
            go (pos + wordSize) (kept + picked)
-- | Keep only the bits of @xs@ whose corresponding bit in @is@ is clear, and
-- pack them to the front of @xs@. Returns how many bits were kept.
--
-- Only the first @min (length is) (length xs)@ positions are considered.
-- The kept bits occupy the prefix of @xs@ of the returned length. Contents
-- after that prefix are whatever the word-sized writes left there and should
-- not be relied on.
excludeBitsInPlace
  :: PrimMonad m => U.Vector Bit -> U.MVector (PrimState m) Bit -> m Int
excludeBitsInPlace is xs = go 0 0
  where
    !limit = min (U.length is) (MU.length xs)

    -- @pos@: next source bit to read; @kept@: bits written so far.
    -- Writing at @kept <= pos@ never clobbers bits not yet read.
    go !pos !kept
      | pos >= limit = pure kept
      | otherwise = do
        word <- readWord xs pos
        -- Invert the selector so that clear bits in @is@ select.
        let selector = complement (indexWord is pos)
        case selectWord (masked (limit - pos) selector) word of
          (picked, packed) -> do
            writeWord xs kept packed
            go (pos + wordSize) (kept + picked)
-- | Reverse the order of the bits of a mutable bit vector, in place.
--
-- Works inwards from both ends, a word at a time. Each step swaps (and
-- bit-reverses) the outermost chunks of the window that is still unreversed.
reverseInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> m ()
reverseInPlace xs
  -- Guard needed: the final branch of 'loop' would otherwise read word 0 of
  -- an empty vector.
  | len == 0  = pure ()
  | otherwise = loop 0
  where
    len = MU.length xs

    -- Invariant: bits outside the window [i, j) are already in their final
    -- positions; the window itself still has to be reversed.
    loop !i
      -- At least two full words remain: swap the outermost words, reversing
      -- each, and shrink the window by a word on each side.
      | i' <= j' = do
        x <- readWord xs i
        y <- readWord xs j'
        writeWord xs i  (reverseWord y)
        writeWord xs j' (reverseWord x)
        loop i'
      -- Between one and two words remain: swap the outer halves of width
      -- w = floor ((j - i) / 2). This leaves at most one bit between them,
      -- which is its own reversal, so the work is complete. (Recursing here,
      -- as before, would reach the last branch with a negative width.)
      -- NOTE(review): presumably @meld w lo hi@ takes the low @w@ bits from
      -- @lo@ and the rest from @hi@ — see Data.Bit.Utils. Both words are read
      -- before either write, and the write order is significant.
      | i' < j = do
        let w = (j - i) `shiftR` 1
            k = j - w
        x <- readWord xs i
        y <- readWord xs k
        writeWord xs i (meld w (reversePartialWord w y) x)
        writeWord xs k (meld w (reversePartialWord w x) y)
      -- At most one word remains (0 <= j - i <= wordSize): reverse the low
      -- @j - i@ bits of the word at @i@ in place.
      | otherwise = do
        let w = j - i
        x <- readWord xs i
        writeWord xs i (meld w (reversePartialWord w x) x)
      where
        !j  = len - i
        !i' = i + wordSize
        !j' = j - wordSize
#if __GLASGOW_HASKELL__ >= 800
{-# SPECIALIZE reverseInPlace :: U.MVector s Bit -> ST s () #-}
#endif