{-# LANGUAGE CPP                        #-}

{-# LANGUAGE BangPatterns               #-}
{-# LANGUAGE DeriveDataTypeable         #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase                 #-}
{-# LANGUAGE MagicHash                  #-}
{-# LANGUAGE RankNTypes                 #-}

#ifndef BITVEC_THREADSAFE
module Data.Bit.F2Poly
#else
module Data.Bit.F2PolyTS
#endif
  ( F2Poly
  , unF2Poly
  , toF2Poly
  , gcdExt
  ) where

import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.ST
#ifndef BITVEC_THREADSAFE
import Data.Bit.Immutable
import Data.Bit.Internal
import Data.Bit.Mutable
#else
import Data.Bit.ImmutableTS
import Data.Bit.InternalTS
import Data.Bit.MutableTS
#endif
import Data.Bit.Utils
import Data.Bits
import Data.Char
import Data.Coerce
import Data.Primitive.ByteArray
import Data.Typeable
import qualified Data.Vector.Primitive as P
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
import GHC.Exts
import GHC.Generics
import Numeric

#ifdef MIN_VERSION_ghc_bignum
import GHC.Num.BigNat
import GHC.Num.Integer
#else
import GHC.Integer.GMP.Internals
import GHC.Integer.Logarithms
#endif

-- | Binary polynomials of one variable, backed
-- by an unboxed 'Data.Vector.Unboxed.Vector' 'Bit'.
--
-- Polynomials are stored normalized, without leading zero coefficients.
-- Because the representation is canonical, the derived structural 'Eq'
-- on the coefficient vector coincides with equality of polynomials.
--
-- 'Ord' instance does not make much sense mathematically,
-- it is defined only for the sake of 'Data.Set.Set', 'Data.Map.Map', etc.
--
-- >>> :set -XBinaryLiterals
-- >>> -- (1 + x) (1 + x + x^2) = 1 + x^3 (mod 2)
-- >>> 0b11 * 0b111 :: F2Poly
-- 0b1001
newtype F2Poly = F2Poly {
  unF2Poly :: U.Vector Bit
  -- ^ Convert 'F2Poly' to a vector of coefficients
  -- (first element corresponds to a constant term).
  --
  -- >>> :set -XBinaryLiterals
  -- >>> unF2Poly 0b1101
  -- [1,0,1,1]
  }
  deriving (Eq, Ord, Typeable, Generic, NFData)

-- | Make 'F2Poly' from a list of coefficients
-- (first element corresponds to a constant term).
--
-- Leading zero coefficients are stripped, so the result is normalized.
--
-- >>> :set -XOverloadedLists
-- >>> toF2Poly [1,0,1,1,0,0]
-- 0b1101
toF2Poly :: U.Vector Bit -> F2Poly
-- The cloneToWords/castFromWords round trip copies the input into a fresh
-- word-backed vector (presumably offset 0 with a zero-padded final word,
-- the layout the other F2Poly operations expect -- TODO confirm against
-- Data.Bit). dropWhileEnd then trims the high zero coefficients.
toF2Poly xs = F2Poly $ dropWhileEnd $ castFromWords $ cloneToWords xs

-- -- | Valid 'F2Poly' has offset 0 and no trailing garbage.
-- _isValid :: F2Poly -> Bool
-- _isValid (F2Poly (BitVec o l arr)) = o == 0 && l == l'
--   where
--     l' = U.length $ dropWhileEnd $ BitVec 0 (sizeofByteArray arr `shiftL` 3) arr

-- | Addition and multiplication are evaluated modulo 2.
--
-- 'abs' = 'id' and 'signum' = 'const' 1.
--
-- 'fromInteger' converts a binary polynomial, encoded as 'Integer',
-- to 'F2Poly' encoding. Negative integers throw 'Underflow'.
instance Num F2Poly where
  -- Over GF(2) addition and subtraction are the same operation: bitwise XOR.
  (+) = coerce xorBits
  (-) = coerce xorBits
  negate = id
  abs    = id
  signum = const (F2Poly (U.singleton (Bit True)))
  -- karatsuba does not normalize its result, hence the dropWhileEnd.
  (*) = coerce ((dropWhileEnd .) . karatsuba)
#ifdef MIN_VERSION_ghc_bignum
  fromInteger !n = case n of
    -- Small non-negative integer: the polynomial length is the bit length
    -- of the word, i.e. wordSize minus its leading zeros.
    IS i#
      | n < 0     -> throw Underflow
      | otherwise -> F2Poly $ BitVec 0 (wordSize - I# (word2Int# (clz# (int2Word# i#))))
                     $ ByteArray (bigNatFromWord# (int2Word# i#))
    -- Large positive integer: reuse the limb array as bit storage;
    -- the length is floor(log2 n) + 1.
    IP bn# -> F2Poly $ BitVec 0 (I# (word2Int# (integerLog2# n)) + 1) $ ByteArray bn#
    IN{}   -> throw Underflow
  {-# INLINE fromInteger #-}
#else
  fromInteger !n = case n of
    -- Small non-negative integer: the polynomial length is the bit length
    -- of the word, i.e. wordSize minus its leading zeros.
    S# i#
      | n < 0     -> throw Underflow
      | otherwise -> F2Poly $ BitVec 0 (wordSize - I# (word2Int# (clz# (int2Word# i#))))
                     $ fromBigNat $ wordToBigNat (int2Word# i#)
    -- Large positive integer: reuse the limb array as bit storage;
    -- the length is floor(log2 n) + 1.
    Jp# bn# -> F2Poly $ BitVec 0 (I# (integerLog2# n) + 1) $ fromBigNat bn#
    Jn#{}   -> throw Underflow
  {-# INLINE fromInteger #-}
#endif

  {-# INLINE (+)         #-}
  {-# INLINE (-)         #-}
  {-# INLINE negate      #-}
  {-# INLINE abs         #-}
  {-# INLINE signum      #-}
  {-# INLINE (*)         #-}

-- | Conversions between 'F2Poly' and the binary encoding of an 'Int'.
--
-- 'fromEnum' goes through 'toInteger', so polynomials with more than
-- wordSize coefficients are truncated, as 'fromIntegral' does for any
-- 'Integer' to 'Int' conversion.
--
-- 'toEnum' does not reject negative arguments, unlike 'fromInteger'.
-- It reinterprets the 'Int' as a 'Word' (@int2Word#@), so a negative
-- value yields the polynomial of its two's-complement bit pattern.
instance Enum F2Poly where
  fromEnum = fromIntegral
#ifdef MIN_VERSION_ghc_bignum
  toEnum !(I# i#) = F2Poly $ BitVec 0 (wordSize - I# (word2Int# (clz# (int2Word# i#))))
                           $ ByteArray (bigNatFromWord# (int2Word# i#))
#else
  toEnum !(I# i#) = F2Poly $ BitVec 0 (wordSize - I# (word2Int# (clz# (int2Word# i#))))
                           $ fromBigNat $ wordToBigNat (int2Word# i#)
#endif

-- | 'toRational' interprets the polynomial via its 'Integer' encoding
-- (see 'toInteger').
instance Real F2Poly where
  toRational = fromIntegral

-- | 'toInteger' converts a binary polynomial, encoded as 'F2Poly',
-- to 'Integer' encoding.
--
-- 'quotRem' is polynomial long division over GF(2). It throws
-- 'DivideByZero' for a zero divisor.
instance Integral F2Poly where
#ifdef MIN_VERSION_ghc_bignum
  toInteger xs = integerFromBigNat# (bitsToByteArray (unF2Poly xs))
#else
  toInteger xs = bigNatToInteger (BN# (bitsToByteArray (unF2Poly xs)))
#endif
  quotRem (F2Poly xs) (F2Poly ys) = (F2Poly (dropWhileEnd qs), F2Poly (dropWhileEnd rs))
    where
      -- quotRemBits returns un-normalized vectors; trim them here.
      (qs, rs) = quotRemBits xs ys
  -- Polynomials have no sign, so floored and truncated division coincide.
  divMod = quotRem
  mod = rem

-- | Render as @0b@ followed by the binary digits of the 'Integer'
-- encoding, highest-degree coefficient first. The zero polynomial
-- renders as @0b0@.
instance Show F2Poly where
  show p = "0b" ++ showIntAtBase 2 intToDigit (toInteger p) ""

-- | Add two coefficient vectors over GF(2), i.e. XOR them bitwise.
--
-- Inputs must be valid for wrapping into F2Poly: no trailing garbage is
-- allowed. The offset-0 fast path XORs the whole backing arrays, so any
-- bits stored past an input's length would leak into the result.
--
-- The equal-length fast path and the generic path strip high zero
-- coefficients. The unequal-length fast paths return a vector as long as
-- the longer input, which is normalized whenever that input is.
xorBits
  :: U.Vector Bit
  -> U.Vector Bit
  -> U.Vector Bit
xorBits (BitVec _ 0 _) ys = ys
xorBits xs (BitVec _ 0 _) = xs
-- GMP has platform-dependent ASM implementations for mpn_xor_n,
-- which are impossible to beat by native Haskell.
#ifdef MIN_VERSION_ghc_bignum
xorBits (BitVec 0 lx (ByteArray xarr)) (BitVec 0 ly (ByteArray yarr)) = case lx `compare` ly of
  LT -> BitVec 0 ly zs
  -- Equal lengths may cancel the top coefficients. The XOR result array
  -- may be shorter than lx bits, so clamp before trimming.
  EQ -> dropWhileEnd $ BitVec 0 (lx `min` (sizeofByteArray zs `shiftL` 3)) zs
  GT -> BitVec 0 lx zs
  where
    zs = ByteArray (xarr `bigNatXor` yarr)
#else
xorBits (BitVec 0 lx xarr) (BitVec 0 ly yarr) = case lx `compare` ly of
  LT -> BitVec 0 ly zs
  -- Equal lengths may cancel the top coefficients. The XOR result array
  -- may be shorter than lx bits, so clamp before trimming.
  EQ -> dropWhileEnd $ BitVec 0 (lx `min` (sizeofByteArray zs `shiftL` 3)) zs
  GT -> BitVec 0 lx zs
  where
    zs = fromBigNat (toBigNat xarr `xorBigNat` toBigNat yarr)
#endif
-- Generic path (non-zero offsets): XOR the common prefix word by word,
-- then copy the tail of the longer operand.
xorBits xs ys = dropWhileEnd $ runST $ do
  let lx = U.length xs
      ly = U.length ys
      (shorterLen, longerLen, longer) = if lx >= ly then (ly, lx, xs) else (lx, ly, ys)
  zs <- MU.replicate longerLen (Bit False)
  forM_ [0, wordSize .. shorterLen - 1] $ \i ->
    writeWord zs i (indexWord xs i `xor` indexWord ys i)
  -- The last partial word above may carry bits past shorterLen. They are
  -- overwritten here by the genuine tail of the longer operand.
  U.unsafeCopy (MU.drop shorterLen zs) (U.drop shorterLen longer)
  U.unsafeFreeze zs

-- | Operand length (in bits) at or below which 'karatsuba' falls back to
-- schoolbook multiplication.
--
-- Must be >= 2 * wordSize. 'karatsuba' splits at
-- @m = (min lenXs lenYs + 1) / (2 * wordSize)@ words, so this bound
-- guarantees @m >= 1@ and a non-trivial split.
karatsubaThreshold :: Int
karatsubaThreshold = 2048

-- | Multiply two coefficient vectors over GF(2).
--
-- Equal operands are squared with 'sqrBits'. Short operands (either
-- length <= 'karatsubaThreshold') use schoolbook 'mulBits'. Otherwise one
-- Karatsuba step is applied, splitting both operands at @m'@ bits:
--
-- > x*y = z0 + (z11 - z0 - z2) * t^m' + z2 * t^(2m')
--
-- Here z0 = x0*y0, z2 = x1*y1 and z11 = (x0+x1)*(y0+y1). Over GF(2),
-- subtraction is XOR.
--
-- The result is not normalized: it may carry high zero coefficients, so
-- callers trim it with 'dropWhileEnd'.
karatsuba :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
karatsuba xs ys
  | xs == ys = sqrBits xs
  | lenXs <= karatsubaThreshold || lenYs <= karatsubaThreshold
  = mulBits xs ys
  | otherwise = runST $ do
    -- Zero-initialized rather than MU.unsafeNew. Per the bitvec docs,
    -- writeWord truncates a word that extends past the end of the vector,
    -- so with an uninitialized buffer the high bits of the final partial
    -- word would keep memory garbage. (*) wraps this buffer directly into
    -- an F2Poly, whose "no trailing garbage" invariant xorBits (which XORs
    -- whole backing arrays) and sqrBits rely on.
    zs <- MU.replicate lenZs (Bit False)
    forM_ [0 .. divWordSize (lenZs - 1)] $ \k -> do
      -- Word k of the product gathers z0 at shift 0, the middle term
      -- (z11 ^ z0 ^ z2) at shift m words and z2 at shift 2m words.
      -- indexWord0 yields 0 for out-of-range word indices.
      let z0  = indexWord0 zs0   k
          z11 = indexWord0 zs11 (k - m)
          z10 = indexWord0 zs0  (k - m)
          z12 = indexWord0 zs2  (k - m)
          z2  = indexWord0 zs2  (k - 2 * m)
      writeWord zs (mulWordSize k) (z0 `xor` z11 `xor` z10 `xor` z12 `xor` z2)
    U.unsafeFreeze zs
  where
    lenXs = U.length xs
    lenYs = U.length ys
    lenZs = lenXs + lenYs - 1

    -- Split point in words (m) and bits (m'): about half of the shorter
    -- operand, rounded to whole words.
    m    = (min lenXs lenYs + 1) `unsafeShiftR` (lgWordSize + 1)
    m'   = mulWordSize m

    xs0  = U.unsafeSlice 0 m' xs
    xs1  = U.unsafeSlice m' (lenXs - m') xs
    ys0  = U.unsafeSlice 0 m' ys
    ys1  = U.unsafeSlice m' (lenYs - m') ys

    xs01 = xorBits xs0 xs1
    ys01 = xorBits ys0 ys1
    zs0  = karatsuba xs0 ys0
    zs2  = karatsuba xs1 ys1
    zs11 = karatsuba xs01 ys01

-- | Read the word with word index @i'@ (bit offset @i' * wordSize@),
-- treating every bit outside the vector as zero.
--
-- A negative index or one past the end yields 0. A final partial word is
-- masked to the bits that lie inside the vector.
indexWord0 :: U.Vector Bit -> Int -> Word
indexWord0 bv i'
  | i < 0 || lenI <= 0 = 0
  | lenI >= wordSize   = word
  | otherwise          = word .&. loMask lenI
  where
    -- Bit offset of the requested word.
    i     = mulWordSize i'
    -- Number of vector bits available from offset i onward.
    lenI  = U.length bv - i
    word  = indexWord bv i

-- | Schoolbook multiplication of two coefficient vectors over GF(2).
--
-- Returns 'U.empty' if either operand is empty. Otherwise the longer
-- operand is passed first to 'mulBits'', so its per-coefficient loop runs
-- over the shorter one. The result is not normalized.
mulBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
mulBits xs ys
  | lenXs == 0 || lenYs == 0 = U.empty
  | lenXs >= lenYs           = mulBits' xs ys
  | otherwise                = mulBits' ys xs
  where
    lenXs = U.length xs
    lenYs = U.length ys

-- | Worker for 'mulBits'. For every set coefficient @k@ of @ys@, XOR @xs@
-- into the product starting at position @k@.
--
-- Both operands must be non-empty, otherwise @lenZs@ would be negative;
-- 'mulBits' guarantees this. The result has length @lenXs + lenYs - 1@.
mulBits' :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
mulBits' xs ys = runST $ do
  zs <- MU.replicate lenZs (Bit False)
  forM_ [0 .. lenYs - 1] $ \k ->
    when (unBit (U.unsafeIndex ys k)) $
      -- Presumably XORs xs into the leading lenXs bits of the slice that
      -- starts at k, i.e. adds xs * t^k.
      zipInPlace xor xs (MU.unsafeSlice k (lenZs - k) zs)
  U.unsafeFreeze zs
  where
    lenXs = U.length xs
    lenYs = U.length ys
    lenZs = lenXs + lenYs - 1

-- | Square a coefficient vector over GF(2).
--
-- Squaring is linear in characteristic 2: @(sum a_i t^i)^2 = sum a_i t^(2i)@.
-- Each input word is therefore spread by 'sparseBits' into two output
-- words; presumably the low and high halves get a zero between adjacent
-- bits (TODO confirm against Data.Bit.Utils).
--
-- The result has length @2 * nWords lenXs * wordSize@ and is not
-- normalized.
sqrBits :: U.Vector Bit -> U.Vector Bit
sqrBits xs = runST $ do
  let lenXs = U.length xs
  zs <- MU.replicate (mulWordSize (nWords lenXs `shiftL` 1)) (Bit False)
  forM_ [0, wordSize .. lenXs - 1] $ \i -> do
    -- For the final, partial word, indexWord may return bits past the end
    -- of the vector. Mask them off so the square depends only on the
    -- vector's own coefficients, not on what the backing array holds
    -- beyond lenXs.
    let w | lenXs - i >= wordSize = indexWord xs i
          | otherwise             = indexWord xs i .&. loMask (lenXs - i)
        (z0, z1) = sparseBits w
    writeWord zs (i `shiftL` 1) z0
    writeWord zs ((i `shiftL` 1) + wordSize) z1
  U.unsafeFreeze zs

-- | Long division of coefficient vectors over GF(2), returning
-- @(quotient, remainder)@.
--
-- Throws 'DivideByZero' if the divisor is empty. If the dividend is
-- shorter than the divisor, returns @(U.empty, xs)@.
--
-- Assumes the divisor's top coefficient is 1, as for a normalized
-- 'F2Poly'. Otherwise XOR-ing it in would not clear the leading remainder
-- bit. The remainder is returned with length @lenYs@, and neither
-- component is normalized.
quotRemBits :: U.Vector Bit -> U.Vector Bit -> (U.Vector Bit, U.Vector Bit)
quotRemBits xs ys
  | U.null ys = throw DivideByZero
  | U.length xs < U.length ys = (U.empty, xs)
  | otherwise = runST $ do
    let lenXs = U.length xs
        lenYs = U.length ys
        lenQs = lenXs - lenYs + 1
    qs <- MU.replicate lenQs (Bit False)
    -- The remainder starts as a mutable copy of the dividend.
    rs <- MU.replicate lenXs (Bit False)
    U.unsafeCopy rs xs
    -- Work from the highest quotient position down. If the remainder bit
    -- aligned with the divisor's top coefficient is set, record a quotient
    -- bit and subtract (XOR) the shifted divisor.
    forM_ [lenQs - 1, lenQs - 2 .. 0] $ \i -> do
      Bit r <- MU.unsafeRead rs (lenYs - 1 + i)
      when r $ do
        MU.unsafeWrite qs i (Bit True)
        zipInPlace xor ys (MU.drop i rs)
    -- All bits from lenYs - 1 upward have been cleared, so the first lenYs
    -- bits hold the whole remainder.
    let rs' = MU.unsafeSlice 0 lenYs rs
    (,) <$> U.unsafeFreeze qs <*> U.unsafeFreeze rs'

-- | Strip trailing zero bits, i.e. the high-order zero coefficients, by
-- slicing the vector down to just past its last set bit.
--
-- The scan goes backwards one word at a time. For an all-zero vector the
-- result is empty. The backing array is shared, not copied.
dropWhileEnd
  :: U.Vector Bit
  -> U.Vector Bit
dropWhileEnd xs = U.unsafeSlice 0 (go (U.length xs)) xs
  where
    -- go n: length of the prefix xs[0 .. n) after dropping its trailing zeros.
    go n
      -- Fewer than wordSize bits remain: mask the first word to those n bits.
      | n < wordSize = wordSize - countLeadingZeros (indexWord xs 0 .&. loMask n)
      -- Inspect the full word ending at bit n.
      | otherwise    = case indexWord xs (n - wordSize) of
        0 -> go (n - wordSize)
        w -> n - countLeadingZeros w

-- | Copy the coefficients into a fresh word array and return its raw
-- 'ByteArray#', used by 'toInteger' as BigNat limbs.
--
-- An empty vector becomes a single zero word. Presumably this is because
-- a BigNat is expected to have at least one limb (TODO confirm for
-- ghc-bignum). The offset and length of the primitive vector are ignored.
-- That is only sound because cloneToWords presumably returns a freshly
-- allocated, offset-0 vector.
bitsToByteArray :: U.Vector Bit -> ByteArray#
bitsToByteArray xs = arr
  where
    ys = if U.null xs then U.singleton 0 else cloneToWords xs
    -- Strict binding: arr is an unlifted ByteArray#.
    !(P.Vector _ _ (ByteArray arr)) = toPrimVector ys

#ifndef MIN_VERSION_ghc_bignum
-- | Zero-copy view of an integer-gmp 'BigNat' as its underlying limb array.
fromBigNat :: BigNat -> ByteArray
fromBigNat (BN# arr) = ByteArray arr

-- | Zero-copy reinterpretation of a limb array as an integer-gmp 'BigNat'.
toBigNat :: ByteArray -> BigNat
toBigNat (ByteArray arr) = BN# arr
#endif

-- | Execute the extended Euclidean algorithm.
-- For polynomials @a@ and @b@, compute their unique greatest common divisor @g@
-- and the unique coefficient polynomial @s@ satisfying @as + bt = g@.
-- The result is the pair @(g, s)@.
--
-- >>> :set -XBinaryLiterals
-- >>> gcdExt 0b101 0b0101
-- (0b101,0b0)
-- >>> gcdExt 0b11 0b111
-- (0b1,0b10)
gcdExt :: F2Poly -> F2Poly -> (F2Poly, F2Poly)
gcdExt = go 1 0
  where
    -- Invariant: s * a is congruent to r, and s' * a to r', modulo b.
    -- The coefficients are forced on each step so they do not accumulate
    -- a chain of unevaluated (s - q * s') thunks across iterations.
    go :: F2Poly -> F2Poly -> F2Poly -> F2Poly -> (F2Poly, F2Poly)
    go !s !s' r r'
      | r' == 0   = (r, s)
      | otherwise = case quotRem r r' of
        (q, r'') -> go s' (s - q * s') r' r''