{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
#ifndef BITVEC_THREADSAFE
module Data.Bit.F2Poly
#else
module Data.Bit.F2PolyTS
#endif
( F2Poly
, unF2Poly
, toF2Poly
) where
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.ST
#ifndef BITVEC_THREADSAFE
import Data.Bit.Immutable
import Data.Bit.Internal
import Data.Bit.Mutable
#else
import Data.Bit.ImmutableTS
import Data.Bit.InternalTS
import Data.Bit.MutableTS
#endif
import Data.Bit.Utils
import Data.Bits
import Data.Coerce
import Data.List hiding (dropWhileEnd)
import Data.Typeable
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
import GHC.Generics
#if UseIntegerGmp
import Data.Primitive.ByteArray
import qualified Data.Vector.Primitive as P
import GHC.Exts
import GHC.Integer.GMP.Internals
import GHC.Integer.Logarithms
import Unsafe.Coerce
#endif
-- | A polynomial in one variable over GF(2), represented by its vector of
-- coefficients: the element at index @i@ is the coefficient of @x^i@.
--
-- Values built by 'toF2Poly' and by the numeric instances in this module are
-- kept free of trailing zero coefficients, so the derived 'Eq' compares
-- polynomials rather than padded representations.
newtype F2Poly = F2Poly {
  unF2Poly :: U.Vector Bit
  -- ^ The underlying coefficient vector.
  }
  deriving (Eq, Ord, Show, Typeable, Generic, NFData)
-- | Build an 'F2Poly' from a vector of coefficients (index @i@ holds the
-- coefficient of @x^i@).
--
-- The input is first copied through 'cloneToWords' / 'castFromWords' and the
-- copy then has its trailing zero coefficients removed by 'dropWhileEnd'.
toF2Poly :: U.Vector Bit -> F2Poly
toF2Poly = F2Poly . dropWhileEnd . castFromWords . cloneToWords
-- | Ring operations on polynomials over GF(2).
--
-- Addition and subtraction are both coefficient-wise XOR ('xorBits'), so
-- 'negate' is the identity. 'abs' is the identity as well and 'signum' is
-- always the constant polynomial 1. Multiplication goes through 'karatsuba'
-- and strips trailing zero coefficients from the product.
--
-- 'fromInteger' reads the binary digits of a non-negative 'Integer' as
-- coefficients (bit @i@ becomes the coefficient of @x^i@). Every negative
-- argument, small or large, is rejected with 'error'.
instance Num F2Poly where
  (+) = coerce xorBits
  (-) = coerce xorBits
  negate = id
  abs = id
  signum = const (F2Poly (U.singleton (Bit True)))
  (*) = coerce ((dropWhileEnd .) . karatsuba)
#if UseIntegerGmp
  fromInteger !n = case n of
    -- Machine-word sized integer. A negative value must be rejected here:
    -- otherwise its two's complement pattern would be reinterpreted as an
    -- unsigned word, while the large negative case below raises an error.
    S# i#
      | n < 0 -> error "F2Poly.fromInteger: argument must be non-negative"
      | otherwise ->
          -- The vector length is the position of the highest set bit.
          F2Poly $ BitVec 0 (wordSize - I# (word2Int# (clz# (int2Word# i#))))
                 $ fromBigNat $ wordToBigNat (int2Word# i#)
    -- Large positive integer: reuse its limb array as the bit storage.
    Jp# bn# -> F2Poly $ BitVec 0 (I# (integerLog2# n) + 1) $ fromBigNat bn#
    Jn#{} -> error "F2Poly.fromInteger: argument must be non-negative"
#else
  fromInteger n
    | n < 0 = error "F2Poly.fromInteger: argument must be non-negative"
    | otherwise = F2Poly (dropWhileEnd (integerToBits n))
#endif
-- | Conversions between 'F2Poly' and 'Int', using the same binary encoding as
-- 'fromInteger' and 'toInteger': bit @i@ of the number is the coefficient of
-- @x^i@.
instance Enum F2Poly where
  fromEnum poly = fromIntegral poly
#if UseIntegerGmp
  toEnum !(I# i#) = F2Poly (BitVec 0 bitCount limbs)
    where
      -- Number of bits up to and including the highest set one.
      bitCount = wordSize - I# (word2Int# (clz# (int2Word# i#)))
      -- One-limb array holding the word itself.
      limbs = fromBigNat (wordToBigNat (int2Word# i#))
#else
  toEnum n = fromIntegral n
#endif
-- | 'toRational' goes through the 'Integer' encoding given by 'toInteger'.
instance Real F2Poly where
  toRational poly = fromIntegral poly
-- | Polynomial division over GF(2).
--
-- 'toInteger' encodes the polynomial as an 'Integer' whose bit @i@ is the
-- coefficient of @x^i@. 'divMod' and 'mod' are the same operations as
-- 'quotRem' and 'rem'. Division by the zero polynomial throws
-- 'DivideByZero' (raised by 'quotRemBits' / 'remBits').
instance Integral F2Poly where
  toInteger (F2Poly bits) = bitsToInteger bits
  quotRem (F2Poly xs) (F2Poly ys) =
    -- The pair is bound lazily, so the result tuple can be built before the
    -- division itself is forced.
    let (qs, rs) = quotRemBits xs ys
    in (F2Poly (dropWhileEnd qs), F2Poly (dropWhileEnd rs))
  rem (F2Poly xs) (F2Poly ys) = F2Poly (dropWhileEnd (remBits xs ys))
  divMod p q = quotRem p q
  mod p q = rem p q
-- | Coefficient-wise XOR of two bit vectors, i.e. addition of polynomials
-- over GF(2).
--
-- The general implementation returns a vector as long as the longer operand
-- with trailing zero bits removed.
xorBits
  :: U.Vector Bit
  -> U.Vector Bit
  -> U.Vector Bit
#if UseIntegerGmp
-- XOR with an empty vector returns the other operand unchanged.
xorBits (BitVec _ 0 _) rhs = rhs
xorBits lhs (BitVec _ 0 _) = lhs
-- Both operands start at bit offset 0: XOR the backing arrays as big naturals.
xorBits (BitVec 0 lenL arrL) (BitVec 0 lenR arrR)
  | lenL < lenR = BitVec 0 lenR limbs
  | lenL > lenR = BitVec 0 lenL limbs
  -- Equal lengths: the result is clamped to the size in bits of the limb
  -- array that came back, and trailing zero bits are stripped.
  | otherwise =
      dropWhileEnd (BitVec 0 (lenL `min` (sizeofByteArray limbs `shiftL` 3)) limbs)
  where
    limbs = fromBigNat (toBigNat arrL `xorBigNat` toBigNat arrR)
#endif
xorBits lhs rhs = dropWhileEnd $ runST $ do
  let lenL = U.length lhs
      lenR = U.length rhs
      overlap = lenL `min` lenR
      total = lenL `max` lenR
      -- On a tie the left operand counts as the longer one.
      longest = if lenL >= lenR then lhs else rhs
  acc <- MU.replicate total (Bit False)
  -- XOR the overlapping prefix one machine word at a time.
  forM_ [0, wordSize .. overlap - 1] $ \off ->
    writeWord acc off (indexWord lhs off `xor` indexWord rhs off)
  -- Everything past the overlap is copied verbatim from the longer operand.
  U.unsafeCopy (MU.drop overlap acc) (U.drop overlap longest)
  U.unsafeFreeze acc
-- | Operand length, in bits, at or below which 'karatsuba' stops recursing and
-- multiplies with 'mulBits' instead.
karatsubaThreshold :: Int
karatsubaThreshold = 4096
-- | Multiply two coefficient vectors over GF(2).
--
-- Equal operands are squared with 'sqrBits'. If either operand is no longer
-- than 'karatsubaThreshold' the product is computed by 'mulBits'. Otherwise
-- both operands are split at a common position and the product is assembled
-- from three recursive products. The result is not stripped of trailing
-- zero bits.
karatsuba :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
karatsuba xs ys
  | xs == ys = sqrBits xs
  | min lenXs lenYs <= karatsubaThreshold = mulBits xs ys
  | otherwise = runST $ do
      out <- MU.unsafeNew lenOut
      -- Every word of the fresh buffer is written exactly once.
      forM_ [0, wordSize .. lenOut - 1] $ \pos ->
        writeWord out pos (wordAt pos)
      U.unsafeFreeze out
  where
    lenXs = U.length xs
    lenYs = U.length ys
    lenOut = lenXs + lenYs - 1

    -- Half of the shorter operand, rounded up ...
    halfRaw = (min lenXs lenYs + 1) `quot` 2
    -- ... and rounded down to a multiple of the word size unless the
    -- threshold itself is smaller than one word.
    split
      | karatsubaThreshold < wordSize = halfRaw
      | otherwise = halfRaw - modWordSize halfRaw

    xsLo = U.slice 0 split xs
    xsHi = U.slice split (lenXs - split) xs
    ysLo = U.slice 0 split ys
    ysHi = U.slice split (lenYs - split) ys

    lowProd = karatsuba xsLo ysLo
    highProd = karatsuba xsHi ysHi
    crossProd = karatsuba (xorBits xsLo xsHi) (xorBits ysLo ysHi)

    -- Output word at bit position pos:
    --   low + (cross + low + high) * x^split + high * x^(2 * split)
    wordAt pos =
      indexWord0 lowProd pos
        `xor` indexWord0 crossProd shifted
        `xor` indexWord0 lowProd shifted
        `xor` indexWord0 highProd shifted
        `xor` indexWord0 highProd (pos - 2 * split)
      where
        shifted = pos - split
-- | Read one machine word starting at bit position @i@, treating every bit
-- outside the vector (before index 0 or at and beyond its length) as zero.
--
-- For a negative @i@ the word at position 0 is shifted left by @-i@ so that
-- its bits land where they would sit in a window starting at @i@.
indexWord0 :: U.Vector Bit -> Int -> Word
indexWord0 bv i
  -- The window lies entirely before the vector or entirely after it.
  | i <= negate wordSize || remaining <= 0 = 0
  -- The vector covers the rest of the window: no masking required.
  | remaining >= wordSize = raw
  -- The vector ends inside the window: clear the bits beyond its end.
  | otherwise = raw .&. loMask remaining
  where
    -- Distance from the window start to the end of the vector.
    remaining = U.length bv - i
    raw
      | i < 0 = indexWord bv 0 `unsafeShiftL` negate i
      | otherwise = indexWord bv i
-- | Schoolbook product of two coefficient vectors over GF(2).
--
-- An empty operand yields an empty result. Operands that fit into a single
-- machine word are handed to 'mulShortShort' / 'mulLongShort'. In the general
-- case the result has @lenXs + lenYs - 1@ bits and is not stripped of
-- trailing zeros; output coefficient @k@ is the parity of the bitwise AND of
-- a window of @xs@ with a window of the reversed @ys@.
mulBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
mulBits xs ys
  | lenXs == 0 || lenYs == 0 = U.empty
  | lenXs <= wordSize && lenYs <= wordSize = mulShortShort x0 y0
  | lenYs <= wordSize = mulLongShort xs y0
  | lenXs <= wordSize = mulLongShort ys x0
  | otherwise = runST $ do
      zs <- MU.replicate lenZs (Bit False)
      -- Low coefficients: all of xs against the last k + 1 bits of the
      -- reversed ys, which pairs x_t with y_(k - t).
      forM_ [0 .. lenYs - 1] $ \k ->
        MU.unsafeWrite zs k
          (zipAndCountParityBits xs (U.unsafeSlice (lenYs - 1 - k) (k + 1) rys))
      -- High coefficients: the tail of xs starting at k - (lenYs - 1) against
      -- the whole reversed ys. Exactly lenXs - (k - (lenYs - 1)) = lenZs - k
      -- elements of xs remain from that offset; a longer slice would run past
      -- the end of xs and let unrelated bits into the parity.
      forM_ [lenYs .. lenZs - 1] $ \k ->
        MU.unsafeWrite zs k
          (zipAndCountParityBits (U.unsafeSlice (k - (lenYs - 1)) (lenZs - k) xs) rys)
      U.unsafeFreeze zs
  where
    lenXs = U.length xs
    lenYs = U.length ys
    lenZs = lenXs + lenYs - 1
    rys = reverseBits ys
    -- Single-word views of the operands, masked to their lengths; only
    -- demanded when the corresponding operand fits into one word.
    x0 = indexWord xs 0 .&. loMask lenXs
    y0 = indexWord ys 0 .&. loMask lenYs
-- | Product over GF(2) of two polynomials that each fit into one machine
-- word (bit @i@ of a word is the coefficient of @x^i@).
--
-- The operand lengths are derived from the highest set bit of each word. A
-- zero operand yields an empty vector; otherwise the result has
-- @lenXs + lenYs - 1@ bits.
mulShortShort :: Word -> Word -> U.Vector Bit
mulShortShort xs ys
  -- A zero word has length 0, which would make the output buffer one bit
  -- shorter than the range written by the first loop below.
  | xs == 0 || ys == 0 = U.empty
  | otherwise = runST $ do
      zs <- MU.replicate lenZs (Bit False)
      -- Low coefficients: bit t of yk is y_(k - t).
      forM_ [0 .. lenYs - 1] $ \k -> do
        let yk = rys `shiftR` (lenYs - 1 - k)
            l = (k + 1) `min` lenXs
        MU.unsafeWrite zs k (fromIntegral $ popCount $ xs .&. yk .&. loMask l)
      -- High coefficients: bit t of xk is x_(k - (lenYs - 1) + t), and only
      -- lenZs - k of those exist.
      forM_ [lenYs .. lenZs - 1] $ \k -> do
        let xk = xs `shiftR` (k - (lenYs - 1))
            l = (lenZs - k) `min` lenYs
        MU.unsafeWrite zs k (fromIntegral $ popCount $ xk .&. rys .&. loMask l)
      U.unsafeFreeze zs
  where
    clzXs = countLeadingZeros xs
    lenXs = wordSize - clzXs
    clzYs = countLeadingZeros ys
    lenYs = wordSize - clzYs
    lenZs = lenXs + lenYs - 1
    -- ys with its significant bits reversed: bit j holds y_(lenYs - 1 - j).
    rys = reverseWord (ys `shiftL` clzYs)
-- | Product over GF(2) of a coefficient vector with a polynomial that fits
-- into one machine word (bit @i@ of the word is the coefficient of @x^i@).
--
-- The length of the short operand is derived from its highest set bit. The
-- result has @lenXs + lenYs - 1@ bits and is not stripped of trailing zeros.
-- The caller in this module only passes vectors longer than one word.
mulLongShort :: U.Vector Bit -> Word -> U.Vector Bit
mulLongShort xs ys = runST $ do
  zs <- MU.replicate lenZs (Bit False)
  -- Low coefficients only involve the first word of xs; bit t of yk is
  -- y_(k - t).
  forM_ [0 .. lenYs - 1] $ \k -> do
    let yk = rys `shiftR` (lenYs - 1 - k)
        l = (k + 1) `min` lenXs
    MU.unsafeWrite zs k (fromIntegral $ popCount $ x0 .&. yk .&. loMask l)
  -- High coefficients: a word of xs starting at k - (lenYs - 1). Only
  -- lenZs - k bits of xs remain from that offset, so the mask must not be
  -- wider, or bits stored behind the end of xs would enter the parity.
  forM_ [lenYs .. lenZs - 1] $ \k -> do
    let xk = indexWord xs (k - (lenYs - 1))
        l = (lenZs - k) `min` lenYs
    MU.unsafeWrite zs k (fromIntegral $ popCount $ xk .&. rys .&. loMask l)
  U.unsafeFreeze zs
  where
    lenXs = U.length xs
    clzYs = countLeadingZeros ys
    lenYs = wordSize - clzYs
    lenZs = lenXs + lenYs - 1
    -- ys with its significant bits reversed: bit j holds y_(lenYs - 1 - j).
    rys = reverseWord (ys `shiftL` clzYs)
    x0 = indexWord xs 0
-- | Parity of the number of positions at which both vectors have a set bit,
-- looking only at the common prefix (the length of the shorter vector).
--
-- The per-word population counts are combined with XOR, which preserves the
-- lowest bit of their sum; the conversion to 'Bit' is left to its 'Num'
-- instance.
zipAndCountParityBits :: U.Vector Bit -> U.Vector Bit -> Bit
zipAndCountParityBits xs ys
  | tailBits == 0 = fromIntegral fullParity
  | otherwise = fromIntegral (fullParity `xor` tailParity)
  where
    common = U.length xs `min` U.length ys
    -- Number of bits in the final, partial word (0 if there is none).
    tailBits = modWordSize common
    fullEnd = common - tailBits
    andWord off = indexWord xs off .&. indexWord ys off
    -- Contribution of the complete words.
    fullParity = foldl' xor 0 (map (popCount . andWord) [0, wordSize .. fullEnd - 1])
    -- Contribution of the partial word, masked to the bits that count.
    tailParity = popCount (andWord fullEnd .&. loMask tailBits)
-- | Square a coefficient vector over GF(2).
--
-- Each input word is expanded by 'sparseBits' into a pair of words which are
-- written to the output at twice the input offset. The result has twice the
-- length of the input and is not stripped of trailing zeros.
sqrBits :: U.Vector Bit -> U.Vector Bit
sqrBits xs = runST $ do
  out <- MU.replicate (2 * srcLen) (Bit False)
  forM_ [0, wordSize .. srcLen - 1] $ \off -> do
    let (lo, hi) = sparseBits (indexWord xs off)
        dst = 2 * off
    writeWord out dst lo
    writeWord out (dst + wordSize) hi
  U.unsafeFreeze out
  where
    srcLen = U.length xs
-- | Long division of coefficient vectors over GF(2), returning
-- @(quotient, remainder)@.
--
-- Throws 'DivideByZero' for an empty divisor. If the dividend is shorter than
-- the divisor the quotient is empty and the remainder is the dividend itself.
-- Otherwise the remainder is returned as the first @length ys@ bits of the
-- working buffer; neither result is stripped of trailing zeros.
quotRemBits :: U.Vector Bit -> U.Vector Bit -> (U.Vector Bit, U.Vector Bit)
quotRemBits xs ys
  | U.null ys = throw DivideByZero
  | lenXs < lenYs = (U.empty, xs)
  | otherwise = runST $ do
      quotient <- MU.replicate steps (Bit False)
      -- Working copy of the dividend, reduced in place.
      work <- MU.replicate lenXs (Bit False)
      U.unsafeCopy work xs
      -- From the highest alignment downwards: whenever the bit under the top
      -- position of the shifted divisor is set, record a quotient bit and
      -- cancel it by XOR-ing the divisor in.
      forM_ [steps - 1, steps - 2 .. 0] $ \pos ->
        MU.unsafeRead work (lenYs - 1 + pos) >>= \case
          Bit True -> do
            MU.unsafeWrite quotient pos (Bit True)
            zipInPlace xor ys (MU.drop pos work)
          _ -> pure ()
      q <- U.unsafeFreeze quotient
      r <- U.unsafeFreeze (MU.unsafeSlice 0 lenYs work)
      pure (q, r)
  where
    lenXs = U.length xs
    lenYs = U.length ys
    -- Number of quotient coefficients.
    steps = lenXs - lenYs + 1
-- | Remainder of the long division of coefficient vectors over GF(2).
--
-- Performs the same reduction as 'quotRemBits' without recording quotient
-- bits. Throws 'DivideByZero' for an empty divisor and returns the dividend
-- unchanged when it is shorter than the divisor. The result is not stripped
-- of trailing zeros.
remBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
remBits xs ys
  | U.null ys = throw DivideByZero
  | lenXs < lenYs = xs
  | otherwise = runST $ do
      -- Working copy of the dividend, reduced in place.
      work <- MU.replicate lenXs (Bit False)
      U.unsafeCopy work xs
      forM_ [steps - 1, steps - 2 .. 0] $ \pos ->
        MU.unsafeRead work (lenYs - 1 + pos) >>= \case
          Bit True -> zipInPlace xor ys (MU.drop pos work)
          _ -> pure ()
      U.unsafeFreeze (MU.unsafeSlice 0 lenYs work)
  where
    lenXs = U.length xs
    lenYs = U.length ys
    -- Number of divisor alignments to try.
    steps = lenXs - lenYs + 1
-- | Drop trailing zero bits: the result is the prefix of the input that ends
-- at its last set bit (empty if no bit is set). The result is a slice of the
-- input, not a copy.
dropWhileEnd
  :: U.Vector Bit
  -> U.Vector Bit
dropWhileEnd xs = U.unsafeSlice 0 keep xs
  where
    keep = scanDown (U.length xs)

    -- Walk down from the end one word at a time until a non-zero word is
    -- found; its leading zero count then pins down the last set bit.
    scanDown remaining
      | remaining >= wordSize =
          let top = indexWord xs (remaining - wordSize)
          in if top == 0
               then scanDown (remaining - wordSize)
               else remaining - countLeadingZeros top
      -- Fewer than one word left: inspect the masked first word.
      | otherwise =
          wordSize - countLeadingZeros (indexWord xs 0 .&. loMask remaining)
#if UseIntegerGmp
-- | Copy a bit vector into a fresh word array and return the raw array.
--
-- An empty vector is represented by a single zero word. The offset and
-- length stored in the primitive vector are ignored.
bitsToByteArray :: U.Vector Bit -> ByteArray#
bitsToByteArray xs = case toPrimVector ws of
  P.Vector _ _ (ByteArray arr) -> arr
  where
    ws
      | U.null xs = U.singleton 0
      | otherwise = cloneToWords xs
-- | Reinterpret a 'BigNat' as a 'ByteArray' without copying.
--
-- NOTE(review): this relies on both types being a single-constructor box
-- around a @ByteArray#@ -- confirm whenever the underlying packages change.
fromBigNat :: BigNat -> ByteArray
fromBigNat bn = unsafeCoerce bn
-- | Reinterpret a 'ByteArray' as a 'BigNat' without copying.
--
-- NOTE(review): this relies on both types being a single-constructor box
-- around a @ByteArray#@ -- confirm whenever the underlying packages change.
toBigNat :: ByteArray -> BigNat
toBigNat arr = unsafeCoerce arr
-- | Encode a bit vector as an 'Integer' by wrapping its word array (see
-- 'bitsToByteArray') as a 'BigNat'.
bitsToInteger :: U.Vector Bit -> Integer
bitsToInteger xs = bigNatToInteger $ BN# (bitsToByteArray xs)
#else
-- | Expand an 'Integer' into its binary digits, bit @i@ of the number at
-- index @i@.
--
-- The vector has 'bitLen' entries, which may exceed the true bit length, so
-- the result can end in zero bits.
integerToBits :: Integer -> U.Vector Bit
integerToBits x = U.generate (bitLen x) (\i -> Bit (testBit x i))
-- | An upper bound on the number of binary digits of the argument: the
-- smallest power of two @a@ (starting from 1) such that @x < 2^a@.
--
-- The bound is found by trying @a = 1, 2, 4, 8, ...@ in turn.
bitLen :: Integer -> Int
bitLen x = go 0
  where
    go :: Int -> Int
    go k
      | x >= bit width = go (k + 1)
      | otherwise = width
      where
        -- Candidate bound 2^k.
        width = bit k
-- | Encode a bit vector as an 'Integer': every set bit at index @i@ sets
-- bit @i@ of the result.
bitsToInteger :: U.Vector Bit -> Integer
bitsToInteger bits = U.ifoldl' step 0 bits
  where
    step acc i (Bit b)
      | b = acc `setBit` i
      | otherwise = acc
#endif