{-# LANGUAGE BangPatterns               #-}
{-# LANGUAGE CPP                        #-}
{-# LANGUAGE MagicHash                  #-}

module Data.Bit.Utils
  ( lgWordSize
  , modWordSize
  , divWordSize
  , mulWordSize
  , wordSize
  , wordsToBytes
  , nWords
  , aligned
  , alignUp
  , selectWord
  , reverseWord
  , reversePartialWord
  , masked
  , meld
  , ffs
  , loMask
  , hiMask
  , sparseBits
  , fromPrimVector
  , toPrimVector
  ) where

#include "MachDeps.h"

import Data.Bits
import qualified Data.Vector.Primitive as P
import qualified Data.Vector.Unboxed as U
#if __GLASGOW_HASKELL__ >= 810
import GHC.Exts
#endif
import Unsafe.Coerce

import Data.Bit.PdepPext

-- | Number of bits in a machine 'Word' (32 or 64 depending on the platform).
-- Used throughout as the chunk size for 'Word'-based bulk operations on bit
-- vectors.
wordSize :: Int
wordSize = finiteBitSize (maxBound :: Word)

-- | Base-2 logarithm of 'wordSize': 5 on 32-bit and 6 on 64-bit platforms.
--
-- 'wordSize' is a power of two on the platforms GHC targets, so its logarithm
-- is its number of trailing zero bits. This replaces an explicit 32/64 case
-- analysis that fell through to 'error' for any other size.
lgWordSize :: Int
lgWordSize = countTrailingZeros wordSize

-- | Mask selecting the bit-offset-within-a-word part of a bit index
-- (@wordSize - 1@, i.e. all ones in the low 'lgWordSize' bits).
wordSizeMask :: Int
wordSizeMask = pred wordSize

-- | Complement of 'wordSizeMask'. ANDing a bit index with it clears the
-- within-word offset. In two's complement @complement (w - 1) == negate w@,
-- so this is simply @-wordSize@.
wordSizeMaskC :: Int
wordSizeMaskC = negate wordSize

-- | Divide by 'wordSize' by shifting right 'lgWordSize' places.
-- For signed types the shift sign-extends, so negative inputs round towards
-- negative infinity (like 'div'), not towards zero.
divWordSize :: Bits a => a -> a
divWordSize = (`unsafeShiftR` lgWordSize)
{-# INLINE divWordSize #-}

-- | Remainder of a bit index modulo 'wordSize'. The result always lies in
-- @[0, wordSize)@, even for negative inputs, because it keeps only the
-- low 'lgWordSize' bits.
modWordSize :: Int -> Int
modWordSize = (.&. wordSizeMask)
{-# INLINE modWordSize #-}

-- | Multiply by 'wordSize' by shifting left 'lgWordSize' places.
-- Overflow wraps silently, as with any shift.
mulWordSize :: Bits a => a -> a
mulWordSize = (`unsafeShiftL` lgWordSize)
{-# INLINE mulWordSize #-}

-- | Number of 'Word's needed to store @ns@ bits, i.e. @ns / wordSize@
-- rounded up.
--
-- Computed as the truncated quotient plus one for a partial trailing word.
-- The previous form @divWordSize (ns + wordSize - 1)@ overflowed (yielding a
-- negative count) when @ns@ was within @wordSize - 1@ of 'maxBound'. For every
-- other input, negative ones included, both forms agree.
nWords :: Int -> Int
nWords ns
  | aligned ns = q
  | otherwise  = q + 1
  where
    q = divWordSize ns

-- | Convert a count of 'Word's into a count of bytes.
--
-- A 'Word' holds @2 ^ lgWordSize@ bits and a byte holds @2 ^ 3@ bits, so a
-- 'Word' is @2 ^ (lgWordSize - 3)@ bytes. Deriving the shift from
-- 'lgWordSize' keeps the architecture dispatch in one place instead of
-- repeating a partial 32/64-bit case analysis here.
wordsToBytes :: Int -> Int
wordsToBytes ns = ns `unsafeShiftL` (lgWordSize - 3)

-- | Whether a bit index falls exactly on a 'Word' boundary
-- (its within-word offset is zero).
aligned :: Int -> Bool
aligned = (0 ==) . (.&. wordSizeMask)

-- | Round a number of bits up to the nearest multiple of 'wordSize'.
-- Already-aligned values are returned unchanged. Values within one word of
-- 'maxBound' overflow.
alignUp :: Int -> Int
alignUp x
  | aligned x = x
  | otherwise = alignDown x + wordSize

-- | Round a number of bits down to the nearest multiple of 'wordSize'
-- by clearing the within-word offset. Negative inputs round towards
-- negative infinity.
alignDown :: Int -> Int
alignDown = (.&. wordSizeMaskC)

-- | @mask b@ is a 'Word' whose lowest @b@ bits are set and the rest are clear.
-- Out-of-range counts saturate: @b <= 0@ gives 0 and @b >= wordSize@ gives
-- all ones, so the result is defined for every 'Int'.
mask :: Int -> Word
mask b
  | b < 0        = 0
  | b < wordSize = bit b - 1
  | otherwise    = complement 0

-- | Keep only the lowest @b@ bits of a 'Word' (see 'mask' for how
-- out-of-range @b@ is treated).
masked :: Int -> Word -> Word
masked b = (.&. mask b)

-- | Merge two words: the low @b@ bits come from @lo@ and all higher bits
-- from @hi@.
--
-- Uses the xor-merge identity: where the mask is set, @hi xor (hi xor lo)@
-- yields @lo@; where it is clear, the xor term vanishes and @hi@ survives.
meld :: Int -> Word -> Word -> Word
meld b lo hi = hi `xor` ((hi `xor` lo) .&. mask b)
{-# INLINE meld #-}

#if __GLASGOW_HASKELL__ >= 810
-- | Reverse the bit order of a 'Word' (bit 0 swaps with the top bit, and so
-- on), using GHC's 'bitReverse#' primop.
reverseWord :: Word -> Word
reverseWord w = case w of
  W# raw -> W# (bitReverse# raw)
#elif WORD_SIZE_IN_BITS == 64
-- | Reverse the bit order of a 64-bit 'Word'.
--
-- Classic log-step reversal: swap adjacent single bits, then 2-bit pairs,
-- nibbles, bytes, 16-bit halves, and finally the 32-bit halves. Stages are
-- listed last-to-first because '.' applies the rightmost stage first.
reverseWord :: Word -> Word
reverseWord =
    swapGroups 32 0x00000000FFFFFFFF
  . swapGroups 16 0x0000FFFF0000FFFF
  . swapGroups  8 0x00FF00FF00FF00FF
  . swapGroups  4 0x0F0F0F0F0F0F0F0F
  . swapGroups  2 0x3333333333333333
  . swapGroups  1 0x5555555555555555
  where
    -- Exchange each @s@-bit group selected by @m@ with the group just above it.
    swapGroups :: Int -> Word -> Word -> Word
    swapGroups s m w = ((w .&. m) `shiftL` s) .|. ((w `shiftR` s) .&. m)
#elif WORD_SIZE_IN_BITS == 32
-- | Reverse the bit order of a 32-bit 'Word'.
--
-- Log-step reversal: swap adjacent single bits, then 2-bit pairs, nibbles,
-- bytes, and finally the 16-bit halves. Stages are listed last-to-first
-- because '.' applies the rightmost stage first.
reverseWord :: Word -> Word
reverseWord =
    swapGroups 16 0x0000FFFF
  . swapGroups  8 0x00FF00FF
  . swapGroups  4 0x0F0F0F0F
  . swapGroups  2 0x33333333
  . swapGroups  1 0x55555555
  where
    -- Exchange each @s@-bit group selected by @m@ with the group just above it.
    swapGroups :: Int -> Word -> Word -> Word
    swapGroups s m w = ((w .&. m) `shiftL` s) .|. ((w `shiftR` s) .&. m)
#else
#error unsupported WORD_SIZE_IN_BITS config
#endif

-- | Reverse only the lowest @n@ bits of @w@, returning them in the lowest @n@
-- positions. Bits of @w@ at index @n@ and above are discarded.
-- When @n >= wordSize@ this is the same as 'reverseWord'.
reversePartialWord :: Int -> Word -> Word
reversePartialWord n w
  | n < wordSize = reversed `shiftR` (wordSize - n)
  | otherwise    = reversed
  where
    reversed = reverseWord w

-- | Find-first-set: the index of the least significant set bit, or 'Nothing'
-- for zero. The index is forced before being wrapped in 'Just'.
ffs :: Word -> Maybe Int
ffs x
  | x == 0    = Nothing
  | otherwise = Just $! countTrailingZeros x
{-# INLINE ffs #-}

-- | Select the bits of @src@ at the positions set in @msk@.
-- Returns how many bits were selected together with the selected bits as
-- produced by 'pext'.
-- NOTE(review): assumes 'pext' has BMI2 PEXT semantics (selected bits packed
-- contiguously into the low end); confirm against Data.Bit.PdepPext.
selectWord :: Word -> Word -> (Int, Word)
selectWord msk src = (count, packed)
  where
    count  = popCount msk
    packed = pext src msk
{-# INLINE selectWord #-}

#if WORD_SIZE_IN_BITS == 64

-- | Insert a zero bit between each pair of consecutive bits of the input.
-- The low 32 bits are spread into the first component and the high 32 bits
-- into the second, so each input half fills a whole output 'Word'.
-- xyzw --> (x0y0, z0w0)
sparseBits :: Word -> (Word, Word)
sparseBits w = (sparseBitsInternal lowHalf, sparseBitsInternal highHalf)
  where
    lowHalf  = w .&. loMask 32
    highHalf = w `shiftR` 32

-- | Spread the low 32 bits of a 'Word' so that input bit @i@ lands at output
-- bit @2*i@, leaving the odd positions zero.
--
-- Implemented as a sequence of delta swaps with halving distances
-- (16, 8, 4, 2, 1). Stages are listed last-to-first because '.' applies the
-- rightmost stage first.
sparseBitsInternal :: Word -> Word
sparseBitsInternal =
    deltaSwap  1 0x2222222222222222
  . deltaSwap  2 0x0c0c0c0c0c0c0c0c
  . deltaSwap  4 0x00f000f000f000f0
  . deltaSwap  8 0x0000ff000000ff00
  . deltaSwap 16 0x00000000ffff0000
  where
    -- Swap the bits selected by @m@ with the bits @s@ positions above them.
    deltaSwap :: Int -> Word -> Word -> Word
    deltaSwap s m v =
      let d = (v `xor` (v `shiftR` s)) .&. m
      in  v `xor` d `xor` (d `shiftL` s)

#elif WORD_SIZE_IN_BITS == 32

-- | Insert a zero bit between each pair of consecutive bits of the input.
-- The low 16 bits are spread into the first component and the high 16 bits
-- into the second, so each input half fills a whole output 'Word'.
-- xyzw --> (x0y0, z0w0)
sparseBits :: Word -> (Word, Word)
sparseBits w = (sparseBitsInternal lowHalf, sparseBitsInternal highHalf)
  where
    lowHalf  = w .&. loMask 16
    highHalf = w `shiftR` 16

-- | Spread the low 16 bits of a 'Word' so that input bit @i@ lands at output
-- bit @2*i@, leaving the odd positions zero.
--
-- Implemented as a sequence of delta swaps with halving distances
-- (8, 4, 2, 1). Stages are listed last-to-first because '.' applies the
-- rightmost stage first.
sparseBitsInternal :: Word -> Word
sparseBitsInternal =
    deltaSwap 1 0x22222222
  . deltaSwap 2 0x0c0c0c0c
  . deltaSwap 4 0x00f000f0
  . deltaSwap 8 0x0000ff00
  where
    -- Swap the bits selected by @m@ with the bits @s@ positions above them.
    deltaSwap :: Int -> Word -> Word -> Word
    deltaSwap s m v =
      let d = (v `xor` (v `shiftR` s)) .&. m
      in  v `xor` d `xor` (d `shiftL` s)

#else
#error unsupported WORD_SIZE_IN_BITS config
#endif

-- | A 'Word' with exactly the lowest @n@ bits set.
-- Precondition: @0 <= n < finiteBitSize (0 :: Word)@. The shift is
-- 'unsafeShiftL', so out-of-range @n@ gives an unspecified result.
loMask :: Int -> Word
loMask = subtract 1 . unsafeShiftL 1
{-# INLINE loMask #-}

-- | A 'Word' with the lowest @n@ bits clear and all higher bits set
-- (the complement of @loMask n@).
-- In modular arithmetic @complement (a - 1) == negate a@, so this is just
-- @negate (2 ^ n)@.
-- Precondition: @0 <= n < finiteBitSize (0 :: Word)@, because the shift is
-- 'unsafeShiftL'.
hiMask :: Int -> Word
hiMask n = negate (unsafeShiftL 1 n)
{-# INLINE hiMask #-}

-- | Reinterpret a primitive vector of 'Word's as an unboxed one in O(1),
-- without copying.
-- Uses 'unsafeCoerce'. This relies on the vector package representing
-- @U.Vector Word@ as a newtype around @P.Vector Word@.
-- NOTE(review): re-check that representation when upgrading vector.
fromPrimVector :: P.Vector Word -> U.Vector Word
fromPrimVector v = unsafeCoerce v
{-# INLINE fromPrimVector #-}

-- | Reinterpret an unboxed vector of 'Word's as a primitive one in O(1),
-- without copying. This is the inverse of 'fromPrimVector'.
-- Uses 'unsafeCoerce'. This relies on the vector package representing
-- @U.Vector Word@ as a newtype around @P.Vector Word@.
-- NOTE(review): re-check that representation when upgrading vector.
toPrimVector :: U.Vector Word -> P.Vector Word
toPrimVector v = unsafeCoerce v
{-# INLINE toPrimVector #-}