module Data.Graph.Haggle.Internal.BitSet (
  BitSet,
  newBitSet,
  setBit,
  testBit
  ) where

import Control.Monad.ST
import qualified Data.Bits as B
import Data.Vector.Unboxed.Mutable ( STVector )
import qualified Data.Vector.Unboxed.Mutable as V
import Data.Word ( Word64 )

-- | A fixed-size mutable set of bits used inside the 'ST' monad.
--
-- Fields:
--
-- * the backing vector of 'Word64' words that hold the bits;
-- * the number of bits the set was created with.  'setBit' and 'testBit'
--   treat indices at or above this size as out of range.
--
-- Both fields are strict, so a 'BitSet' never holds an unevaluated
-- vector handle.
data BitSet s = BS !(STVector s Word64) {-# UNPACK #-} !Int

-- | How many bits each element of the backing vector holds.  This is taken
-- from the width of 'Word64' itself rather than written as a literal.
bitsPerWord :: Int
bitsPerWord = B.finiteBitSize (0 :: Word64)

-- | Allocate a new 'BitSet' with @n@ bits.  All bits are initialized to
-- zero.  A negative @n@ is treated as zero, which gives an empty set in
-- which every index is out of range.
--
-- > bs <- newBitSet n
newBitSet :: Int -> ST s (BitSet s)
newBitSet n = do
  v <- V.replicate nWords 0
  return $ BS v nBits
  where
    nBits = max 0 n
    -- Ceiling division: allocate exactly enough words to hold nBits bits.
    -- When nBits is a multiple of the word size, no extra word is added,
    -- so an empty set allocates zero words.
    nWords = case nBits `quotRem` bitsPerWord of
      (q, 0) -> q
      (q, _) -> q + 1

-- | Set a bit in the bitset.  Out of range has no effect.
--
-- An index is out of range when it is negative or not less than the size
-- the set was created with.  Both cases are ignored instead of reaching
-- the bounds-checked vector access.
setBit :: BitSet s -> Int -> ST s ()
setBit (BS v sz) bitIx
  | bitIx < 0 || bitIx >= sz = return ()
  | otherwise = do
      -- bitIx is non-negative here, so quotRem agrees with divMod.
      let (wordIx, bitPos) = bitIx `quotRem` bitsPerWord
      oldWord <- V.read v wordIx
      V.write v wordIx (B.setBit oldWord bitPos)

-- | Return True if the bit is set.  Out of range will return False.
--
-- An index is out of range when it is negative or not less than the size
-- the set was created with.  Both cases answer 'False' instead of reaching
-- the bounds-checked vector access.
testBit :: BitSet s -> Int -> ST s Bool
testBit (BS v sz) bitIx
  | bitIx < 0 || bitIx >= sz = return False
  | otherwise = do
      -- bitIx is non-negative here, so quotRem agrees with divMod.
      let (wordIx, bitPos) = bitIx `quotRem` bitsPerWord
      w <- V.read v wordIx
      return (B.testBit w bitPos)