{-# LANGUAGE AllowAmbiguousTypes          #-}
{-# LANGUAGE DeriveAnyClass               #-}
{-# LANGUAGE NoGeneralisedNewtypeDeriving #-}
{-# LANGUAGE TypeApplications             #-}

module ZkFold.Base.Algebra.Polynomials.Univariate
    ( toPoly
    , fromPoly
    , Poly
    , removeZeros
    , scaleP
    , qr
    , eea
    , lt
    , deg
    , vec2poly
    , PolyVec
    , fromPolyVec
    , toPolyVec
    , rewrapPolyVec
    , castPolyVec
    , evalPolyVec
    , scalePV
    , polyVecZero
    , polyVecDiv
    , polyVecLinear
    , polyVecLagrange
    , polyVecGrandProduct
    , polyVecInLagrangeBasis
    , polyVecQuadratic
    , mulVector
    , mulDft
    , mulKaratsuba
    , mulPoly
    , mulPolyKaratsuba
    , mulPolyDft
    , mulPolyNaive
    ) where

import           Control.DeepSeq                  (NFData (..))
import qualified Data.Vector                      as V
import           GHC.Generics                     (Generic)
import           Numeric.Natural                  (Natural)
import           Prelude                          hiding (Num (..), drop, length, product, replicate, sum, take, (/),
                                                   (^))
import qualified Prelude                          as P
import           Test.QuickCheck                  (Arbitrary (..), chooseInt)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.DFT    (genericDft)
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Prelude                   (zipWithDefault)

-------------------------------- Arbitrary degree polynomials --------------------------------

-- TODO (Issue #17): hide constructor
-- | A univariate polynomial of arbitrary degree, stored as its coefficient
-- vector: index @i@ holds the coefficient of @x^i@, so the leading
-- coefficient is the last element.
--
-- The derived 'Eq' compares the raw vectors, so two representations that
-- differ only by trailing zero coefficients compare unequal; build values
-- with 'toPoly' to get the normalised form.
newtype Poly c = P (V.Vector c)
    deriving (Eq, Show, Functor, Generic, NFData)

-- | Builds a polynomial from its coefficient vector (index @i@ is the
-- coefficient of @x^i@) and normalises it with 'removeZeros'
-- (presumably stripping trailing zero coefficients).
toPoly :: (Ring c, Eq c) => V.Vector c -> Poly c
toPoly coefficients = removeZeros (P coefficients)

-- | The raw coefficient vector of a polynomial, lowest power first.
fromPoly :: Poly c -> V.Vector c
fromPoly poly = case poly of
    P coefficients -> coefficients

instance (Ring c, Eq c) => AdditiveSemigroup (Poly c) where
    -- Coefficient-wise sum. Both coefficient vectors are first right-padded
    -- with zeros to the longer length. The sum goes through 'removeZeros'
    -- because the leading terms may cancel.
    P l + P r = removeZeros . P $ V.zipWith (+) (extend l) (extend r)
      where
        width = max (V.length l) (V.length r)
        extend v = v V.++ V.replicate (width P.- V.length v) zero

instance (Ring c, Eq c) => AdditiveMonoid (Poly c) where
    -- The zero polynomial has no coefficients at all, which is why 'deg'
    -- reports -1 for it.
    zero = P V.empty

instance (Ring c, Eq c) => AdditiveGroup (Poly c) where
    -- Negate every coefficient. No re-normalisation is performed; the
    -- vector keeps its length.
    negate (P coefficients) = P (V.map negate coefficients)

instance (Field c, Eq c) => MultiplicativeSemigroup (Poly c) where
    -- Multiplication is delegated to 'mulAdaptive', which picks the algorithm:
    -- if the longer coefficient vector has at most 64 entries (a threshold
    -- determined by benchmarking) it uses naive multiplication; otherwise it
    -- uses DFT multiplication when the field provides a suitable root of
    -- unity, and Karatsuba multiplication when it does not.
    -- The product is normalised because the algorithms may return trailing
    -- zero coefficients (e.g. from padding).
    P l * P r = removeZeros (P (mulAdaptive l r))

-- | Right-pads a vector with 'zero' up to length @l@.
--
-- A vector that is already at least @l@ long comes back with unchanged
-- contents, because 'V.replicate' with a non-positive count is empty.
padVector :: forall a . Ring a => V.Vector a -> Int -> V.Vector a
padVector v l = case l P.- V.length v of
    0       -> v
    missing -> v V.++ V.replicate missing zero

-- | Multiplies two coefficient vectors, picking an algorithm by input size.
--
-- * If either input is empty, the product is empty.
-- * If the longer input has at most @naiveThreshold@ (64, determined by
--   benchmarking) coefficients, schoolbook multiplication ('mulVector') is
--   used.
-- * Otherwise, let @2^p@ be the smallest power of two that is at least the
--   longer input length.
--
--     * If 'rootOfUnity' (p + 1) yields a root, both inputs are zero-padded
--       to @2^(p+1)@ and multiplied via DFT ('mulDft').
--     * Otherwise both inputs are zero-padded to @2^p@ and multiplied with
--       Karatsuba ('mulKaratsuba').
--
-- The result is not normalised and may end with zero coefficients. For
-- instance, the DFT path always returns @2^(p+1)@ coefficients.
mulAdaptive :: forall c . Field c => V.Vector c -> V.Vector c -> V.Vector c
mulAdaptive l r
    | V.null l || V.null r  = V.empty
    | len <= naiveThreshold = mulVector l r
    | otherwise             = case maybeW2n of
        Just w2n -> mulDft (p P.+ 1) w2n (padVector l padDft) (padVector r padDft)
        Nothing  -> mulKaratsuba (padVector l padKaratsuba) (padVector r padKaratsuba)
  where
    -- Inputs this short are faster to multiply naively (benchmarked).
    naiveThreshold :: Int
    naiveThreshold = 64

    len :: Int
    len = max (V.length l) (V.length r)

    -- p is the smallest exponent with 2^p >= len, and padKaratsuba = 2^p.
    -- Both are computed exactly with integers instead of
    -- ceiling (logBase 2 len). Floating-point logBase can round slightly
    -- above an exact integer for powers of two, which would make ceiling
    -- overshoot by one and double the padding.
    p :: Integer
    padKaratsuba :: Int
    (p, padKaratsuba) = until ((>= len) . snd) (\(k, m) -> (k P.+ 1, m P.* 2)) (0, 1)

    -- The DFT (cyclic) product needs room for all 2 * len - 1 coefficients.
    padDft :: Int
    padDft = 2 P.* padKaratsuba

    maybeW2n :: Maybe c
    maybeW2n = rootOfUnity (fromIntegral (p P.+ 1))

-- | Multiplies two zero-padded coefficient vectors in the DFT domain.
--
-- @mulDft p w2n l r@ proceeds in three steps:
--
-- 1. Transform both inputs with 'genericDft' @p w2n@.
-- 2. Multiply the two images pointwise.
-- 3. Apply the inverse transform: 'genericDft' with @one // w2n@, with
--    every coefficient then multiplied by @1 / 2^p@.
--
-- As called from 'mulAdaptive', both inputs have length @2^p@ and are padded
-- so that the cyclic product equals the ordinary product. @w2n@ comes from
-- 'rootOfUnity' @p@ (presumably a primitive @2^p@-th root of unity — other
-- callers must respect the same preconditions).
mulDft :: forall c . Field c => Integer -> c -> V.Vector c -> V.Vector c -> V.Vector c
mulDft p w2n lPadded rPadded = V.map (* nInv) (inverseDft pointwise)
  where
    -- Transform size.
    size :: Int
    size = 2 P.^ p

    forwardDft, inverseDft :: V.Vector c -> V.Vector c
    forwardDft = genericDft p w2n
    inverseDft = genericDft p (one // w2n)

    pointwise :: V.Vector c
    pointwise = V.zipWith (*) (forwardDft lPadded) (forwardDft rPadded)

    -- 1 / size, the normalisation factor of the inverse transform.
    nInv :: c
    nInv = one // fromConstant (fromIntegral @_ @Natural size)

-- | Karatsuba multiplication of two coefficient vectors.
--
-- Both inputs are expected to have the same length, a power of two (as
-- arranged by 'mulAdaptive'). Splitting each input at @half@ gives
--
-- @v1 = lo1 + x^half * hi1@ and @v2 = lo2 + x^half * hi2@,
--
-- so the product is
--
-- @lo1*lo2 + x^half * cross + x^(2*half) * hi1*hi2@,
--
-- where @cross = (hi1 + lo1) * (hi2 + lo2) - hi1*hi2 - lo1*lo2@.
-- Only three sub-multiplications are needed, and each goes back through
-- 'mulAdaptive'. The result has @2 * length v1 - 1@ coefficients.
mulKaratsuba :: forall a. Field a => V.Vector a -> V.Vector a -> V.Vector a
mulKaratsuba v1 v2
  | size == 1 = V.zipWith (*) v1 v2
  | otherwise = V.generate (2 P.* size P.- 1) coeff
  where
    size, half, prodLen :: Int
    size    = V.length v1
    half    = size `P.div` 2
    -- Length of each half-by-half product.
    prodLen = size P.- 1

    lo1, hi1, lo2, hi2 :: V.Vector a
    (lo1, hi1) = V.splitAt half v1
    (lo2, hi2) = V.splitAt half v2

    -- Sub-products may come back longer than prodLen (e.g. DFT padding);
    -- padVector leaves such vectors unchanged.
    hiProd, loProd :: V.Vector a
    hiProd = padVector (mulAdaptive hi1 hi2) prodLen
    loProd = padVector (mulAdaptive lo1 lo2) prodLen

    sumProd :: V.Vector a
    sumProd = mulAdaptive (V.zipWith (+) hi1 lo1) (V.zipWith (+) hi2 lo2)

    -- The cross term hi1*lo2 + lo1*hi2, obtained with one multiplication.
    cross :: V.Vector a
    cross = V.zipWith3 (\s h l -> s - h - l)
              (padVector sumProd prodLen)
              (padVector hiProd prodLen)
              (padVector loProd prodLen)

    -- Coefficient i of the product, built from the three overlapping pieces:
    -- loProd at offset 0, cross at offset half, hiProd at offset 2 * half.
    coeff :: Int -> a
    coeff i
      | i < half              = loProd `V.unsafeIndex` i
      | i < 2 P.* half P.- 1  = loProd `V.unsafeIndex` i + cross `V.unsafeIndex` (i P.- half)
      | i == 2 P.* half P.- 1 = cross `V.unsafeIndex` (half P.- 1)
      | i < 3 P.* half P.- 1  = cross `V.unsafeIndex` (i P.- half) + hiProd `V.unsafeIndex` (i P.- 2 P.* half)
      | otherwise             = hiProd `V.unsafeIndex` (i P.- 2 P.* half)

-- | Schoolbook (quadratic) multiplication of two coefficient vectors.
--
-- Coefficient @k@ of the result is the sum of @v1[i] * v2[j]@ over all
-- @i + j = k@. The terms are accumulated starting from 'zero', with @i@
-- running from its largest admissible value downwards. The result has
-- @length v1 + length v2 - 1@ entries (none when that is negative).
mulVector :: forall a. Field a => V.Vector a -> V.Vector a -> V.Vector a
mulVector v1 v2 = V.generate (n1 P.+ n2 P.- 1) coeff
  where
    n1, n2 :: Int
    n1 = V.length v1
    n2 = V.length v2

    coeff :: Int -> a
    coeff k = accumulate (min k (n1 P.- 1)) (max 0 (k P.- n1 P.+ 1)) zero

    -- Walk i downwards and j upwards along the anti-diagonal i + j = k,
    -- stopping as soon as either index leaves its vector.
    accumulate :: Int -> Int -> a -> a
    accumulate i j acc
      | i < 0 || j >= n2 = acc
      | otherwise        =
          accumulate (i P.- 1) (j P.+ 1) (acc + v1 `V.unsafeIndex` i * v2 `V.unsafeIndex` j)

instance (Field c, Eq c) => Exponent (Poly c) Natural where
    -- Natural powers are delegated to the generic 'natPow', which relies on
    -- the MultiplicativeMonoid instance below.
    poly ^ n = natPow poly n

instance (Field c, Eq c) => MultiplicativeMonoid (Poly c) where
    -- The constant polynomial 1: a single coefficient.
    one = P (V.singleton one)

instance (Ring c, Arbitrary c, Eq c) => Arbitrary (Poly c) where
    -- Draw between 0 and 128 random coefficients and normalise them with
    -- 'toPoly'.
    arbitrary = do
        count <- chooseInt (0, 128)
        toPoly <$> V.replicateM count arbitrary

-- | Leading coefficient: the last stored coefficient.
--
-- Partial: it calls 'V.last', so it fails on the zero polynomial (empty
-- coefficient vector). It is the true leading coefficient only when the
-- polynomial has no trailing zero coefficients.
lt :: Poly c -> c
lt = V.last . fromPoly

-- | Degree of a polynomial, computed as the number of stored coefficients
-- minus one. Degree of zero polynomial is `-1`.
--
-- This is the true degree only for normalised polynomials (no trailing zero
-- coefficients).
deg :: Poly c -> Integer
deg (P coefficients) = toInteger (V.length coefficients) P.- 1

-- | @scaleP a n p@ computes @a * x^n * p@. Every coefficient is multiplied by
-- @a@ and shifted up by @n@ positions.
--
-- The result is not normalised; for example @a = zero@ leaves zero
-- coefficients in place.
scaleP :: Ring c => c -> Natural -> Poly c -> Poly c
scaleP a n (P coefficients) = P (shift V.++ V.map (a *) coefficients)
  where
    shift = V.replicate (fromIntegral n) zero

-- | Polynomial long division.
--
-- @qr a b@ returns @(q, r)@ with @a = q * b + r@ and @deg r < deg b@.
--
-- The divisor is normalised with 'removeZeros' first, so trailing zero
-- coefficients in @b@ cannot turn its leading coefficient into zero.
-- Dividing by the zero polynomial is an 'error'. Without this check the loop
-- would never terminate, because subtracting multiples of the zero
-- polynomial never lowers @deg x@.
qr :: (Field c, Eq c) => Poly c -> Poly c -> (Poly c, Poly c)
qr a b
    | deg divisor < 0 = error "qr: division by the zero polynomial"
    | otherwise       = go a divisor zero
    where
        divisor = removeZeros b

        -- Invariant: a = q * y + x, where y is the (normalised) divisor.
        go x y q = if deg x < deg y then (q, x) else go x' y q'
            where
                -- Cancel the leading term of x with c * x^n * y.
                c = lt x // lt y
                n = fromIntegral (deg x - deg y)
                -- if `deg x < deg y`, `n` is not evaluated, so this would not error out
                x' = x - scaleP c n y
                q' = q + scaleP c n one

-- | Extended Euclidean algorithm.
--
-- @eea a b@ runs the Euclidean remainder sequence of @a@ and @b@ (via 'qr')
-- while tracking the Bezout coefficient of @a@. It returns @(g, s)@:
--
-- * @g@ is the last non-zero remainder, i.e. a greatest common divisor of
--   @a@ and @b@, not rescaled to be monic.
-- * @s@ satisfies @s * a = g (mod b)@. The coefficient for @b@ is not
--   tracked.
--
-- When @b@ is the zero polynomial the result is @(a, one)@.
eea :: (Field c, Eq c) => Poly c -> Poly c -> (Poly c, Poly c)
eea a b = go (a, one) (b, zero)
    where
        -- Invariant: x = s * a (mod b) and y = t * a (mod b).
        go (x, s) (y, t)
            | deg y == -1 = (x, s)
            | otherwise   = go (y, t) (r, s - q * t)
            where
                (q, r) = qr x y

---------------------------------- Fixed degree polynomials ----------------------------------

-- TODO (Issue #17): hide constructor
-- | A polynomial stored as a coefficient vector and tagged with a type-level
-- @size@.
--
-- 'toPolyVec' limits the vector to @size@ coefficients. The raw 'PV'
-- constructor performs no length check.
newtype PolyVec c (size :: Natural) = PV (V.Vector c)
    deriving (Eq, Show, Generic, NFData)

-- | Build a 'PolyVec' from raw coefficients (lowest degree first). The input
-- is right-padded with 'zero' up to @size@ entries, and any coefficients
-- beyond index @size - 1@ are dropped.
toPolyVec :: forall c size . (Ring c, KnownNat size) => V.Vector c -> PolyVec c size
toPolyVec cs = PV (V.take n (addZeros @c @size cs))
  where
    n :: Int
    n = fromIntegral (value @size)

-- | The underlying coefficient vector, exactly as stored (no padding or trimming).
fromPolyVec :: PolyVec c size -> V.Vector c
fromPolyVec pv = case pv of
    PV coeffs -> coeffs

-- | Apply a function directly to the underlying coefficient vector. The
-- result is re-wrapped with 'PV' as-is, so its length is not re-normalised
-- to @size@.
rewrapPolyVec :: (V.Vector c -> V.Vector c) -> PolyVec c size -> PolyVec c size
rewrapPolyVec f = PV . f . fromPolyVec

-- | Convert an arbitrary-degree 'Poly' to a fixed-size 'PolyVec' via
-- 'toPolyVec' (zero-padded or truncated to @size@ coefficients).
poly2vec :: forall c size . (Ring c, KnownNat size) => Poly c -> PolyVec c size
poly2vec p = case p of
    P coeffs -> toPolyVec coeffs

-- | Convert a 'PolyVec' to a 'Poly', stripping trailing (highest-degree)
-- zero coefficients with 'removeZeros'.
vec2poly :: (Ring c, Eq c) => PolyVec c size -> Poly c
vec2poly = removeZeros . P . fromPolyVec

-- | Scaling is delegated to the 'Scale' instance of the underlying vector.
instance Scale c' c => Scale c' (PolyVec c size) where
    scale k = rewrapPolyVec (scale k)

-- | Coefficient-wise addition. 'V.zipWith' stops at the shorter vector, so
-- both operands are expected to hold the same number of coefficients.
instance Ring c => AdditiveSemigroup (PolyVec c size) where
    p + q = PV (V.zipWith (+) (fromPolyVec p) (fromPolyVec q))

-- | The zero polynomial: exactly @size@ zero coefficients.
instance (Ring c, KnownNat size) => AdditiveMonoid (PolyVec c size) where
    zero = PV (V.replicate n zero)
      where
        n :: Int
        n = fromIntegral (value @size)

-- | Coefficient-wise negation.
instance (Ring c, KnownNat size) => AdditiveGroup (PolyVec c size) where
    negate = rewrapPolyVec (V.map negate)

-- | Natural powers, delegated to 'natPow' (built on this type's
-- 'MultiplicativeMonoid' instance).
instance (Field c, KnownNat size, Eq c) => Exponent (PolyVec c size) Natural where
    p ^ n = natPow p n

-- TODO (Issue #18): check for overflow

-- | Polynomial multiplication: both operands are converted to 'Poly' and
-- multiplied there, and the product is forced back to @size@ coefficients
-- by 'poly2vec'. Any terms of degree >= @size@ are silently dropped.
instance (Field c, KnownNat size, Eq c) => MultiplicativeSemigroup (PolyVec c size) where
    l * r = poly2vec (lp * rp)
      where
        lp = vec2poly l
        rp = vec2poly r

-- | The constant polynomial 1: a leading 'one' followed by @size -! 1@ zeros.
-- NOTE(review): assuming @-!@ is truncated subtraction, @size == 0@ still
-- yields a one-element vector — confirm that is intended.
instance (Field c, KnownNat size, Eq c) => MultiplicativeMonoid (PolyVec c size) where
    one = PV (V.cons one (V.replicate padding zero))
      where
        padding :: Int
        padding = fromIntegral (value @size -! 1)

-- | A random 'PolyVec' made of @size@ independently generated coefficients.
instance (Ring c, Arbitrary c, KnownNat size) => Arbitrary (PolyVec c size) where
    arbitrary = do
        coeffs <- V.replicateM n arbitrary
        return (toPolyVec coeffs)
      where
        n :: Int
        n = fromIntegral (value @size)

-- | Linear polynomial @p(x) = a0 + a1 * x@, followed by @size -! 2@ zero
-- coefficients.
-- NOTE(review): the two given coefficients are never truncated, so for
-- @size < 2@ the vector is longer than @size@ (assuming @-!@ saturates at 0).
polyVecLinear :: forall c size . (Ring c, KnownNat size) => c -> c -> PolyVec c size
polyVecLinear a0 a1 = PV (V.fromList [a0, a1] V.++ V.replicate padding zero)
  where
    padding :: Int
    padding = fromIntegral (value @size -! 2)

-- | Quadratic polynomial @p(x) = a0 + a1 * x + a2 * x^2@, followed by
-- @size -! 3@ zero coefficients.
-- NOTE(review): the three given coefficients are never truncated, so for
-- @size < 3@ the vector is longer than @size@ (assuming @-!@ saturates at 0).
polyVecQuadratic :: forall c size . (Ring c, KnownNat size) => c -> c -> c -> PolyVec c size
polyVecQuadratic a0 a1 a2 = PV (V.fromList [a0, a1, a2] V.++ V.replicate padding zero)
  where
    padding :: Int
    padding = fromIntegral (value @size -! 3)

-- | Multiply every coefficient by @c@. The scalar is applied on the right
-- (@coefficient * c@).
scalePV :: Ring c => c -> PolyVec c size -> PolyVec c size
scalePV c = rewrapPolyVec (V.map (\a -> a * c))

-- | Evaluate the polynomial at @x@ as the sum of @cs!i * x^i@ over all stored
-- coefficients. Each power is computed independently through the 'Natural'
-- 'Exponent' instance of @c@.
evalPolyVec :: Ring c => PolyVec c size -> c -> c
evalPolyVec (PV cs) x = sum (V.imap term cs)
  where
    term i coeff = coeff * (x ^ (fromIntegral i :: Natural))

-- | Reinterpret a 'PolyVec' at another size.
--
-- * Growing or keeping the size always succeeds.
-- * Shrinking succeeds only when every coefficient at index @size'@ or
--   beyond is zero.
-- * Otherwise this calls 'error'.
castPolyVec :: forall c size size' . (Ring c, KnownNat size, KnownNat size', Eq c) => PolyVec c size -> PolyVec c size'
castPolyVec (PV cs)
    | growing || onlyZerosDropped = toPolyVec cs
    | otherwise                   = error "castPolyVec: Cannot cast polynomial vector to smaller size!"
  where
    target :: Natural
    target = value @size'

    growing :: Bool
    growing = value @size <= target

    -- only inspected when shrinking, thanks to the lazy (||) above
    onlyZerosDropped :: Bool
    onlyZerosDropped = V.all (== zero) (V.drop (fromIntegral target) cs)

-- | The polynomial @p(x) = x^n - 1@ with @n = size@, stored with @size'@
-- coefficients. Any term of degree >= @size'@ is dropped by 'poly2vec'.
polyVecZero :: forall c size size' . (Field c, KnownNat size, KnownNat size', Eq c) => PolyVec c size'
polyVecZero = poly2vec (xToN - one)
  where
    -- presumably scaleP k n q = k * x^n * q, giving x^n here — confirm against scaleP
    xToN :: Poly c
    xToN = scaleP one (value @size) one

-- | Lagrange basis polynomial @L_i@, stored with @size'@ coefficients:
-- @L_i(omega^i) = 1@ and @L_i(omega^j) = 0@ for @j /= i@, with
-- @1 <= i <= n@ and @1 <= j <= n@ where @n = size@.
--
-- It is computed as @(omega^i / n) * (x^n - 1) / (x - omega^i)@.
-- NOTE(review): this assumes @omega@ is a primitive @size@-th root of unity.
polyVecLagrange :: forall c size size' . (Field c, Eq c, KnownNat size, KnownNat size') =>
    Natural -> c -> PolyVec c size'
polyVecLagrange i omega = scalePV factor (numerator `polyVecDiv` denominator)
  where
    wi = omega ^ i

    factor = wi // fromConstant (value @size)

    -- This is (x^n - 1) - 1. The extra constant only changes the remainder of
    -- division by the linear 'denominator'. 'polyVecDiv' keeps only the first
    -- component of 'qr', so the result is unaffected.
    numerator = polyVecZero @c @size @size' - one

    -- x - omega^i
    denominator = polyVecLinear (negate wi) one

-- | Build @p(x) = c_1 * L_1(x) + c_2 * L_2(x) + ... + c_n * L_n(x)@, where
-- @c_k@ are the entries of the input vector and @L_k@ come from
-- 'polyVecLagrange' with the given @omega@.
polyVecInLagrangeBasis :: forall c size size' . (Field c, Eq c, KnownNat size, KnownNat size') =>
    c -> PolyVec c size -> PolyVec c size'
polyVecInLagrangeBasis omega (PV cs) = sum (V.imap weighted cs)
  where
    -- vector index k (0-based) pairs with basis polynomial L_(k+1)
    weighted k ck = scalePV ck (polyVecLagrange @c @size @size' (fromIntegral (succ k)) omega)

-- | Running ("grand") product of ratios.
--
-- For each index @j@, define
-- @r_j = (a_j + b_j * beta + gamma) / (a_j + sigma_j * beta + gamma)@.
-- Coefficient @i@ of the result is @r_0 * ... * r_(i-1)@, so coefficient 0
-- is 'one'.
--
-- The result always has exactly @size@ coefficients. If the inputs are
-- shorter than @size@, the trailing entries repeat the product of all
-- available ratios.
--
-- Prefix products are shared through a single scan (O(n) multiplications)
-- instead of recomputing each prefix from scratch (O(n^2)).
polyVecGrandProduct :: forall c size . (Field c, KnownNat size) =>
    PolyVec c size -> PolyVec c size -> PolyVec c size -> c -> c -> PolyVec c size
polyVecGrandProduct (PV as) (PV bs) (PV sigmas) beta gamma =
    let ps = fmap (+ gamma) (V.zipWith (+) as (fmap (* beta) bs))
        qs = fmap (+ gamma) (V.zipWith (+) as (fmap (* beta) sigmas))
        ratios = V.zipWith (//) ps qs
        -- prefixes V.! j == product of the first j ratios; length is
        -- V.length ratios + 1. The lazy scan keeps the unused full
        -- product from being forced when size == V.length ratios.
        prefixes = V.scanl (*) one ratios
        lastIx = V.length ratios
        zs = V.generate (fromIntegral (value @size)) (\i -> prefixes V.! min i lastIx)
    in PV zs

-- | Polynomial division. Only the first component of 'qr' (the quotient) is
-- kept; the remainder is discarded. The quotient is re-packed to @size@
-- coefficients by 'poly2vec'.
polyVecDiv :: forall c size . (Field c, KnownNat size, Eq c) =>
    PolyVec c size -> PolyVec c size -> PolyVec c size
polyVecDiv l r =
    let (quotient, _remainder) = qr (vec2poly l) (vec2poly r)
    in poly2vec quotient

-------------------------------- Helper functions --------------------------------

-- | Drop trailing (highest-degree) zero coefficients. An empty or all-zero
-- input yields the empty polynomial.
removeZeros :: (Ring c, Eq c) => Poly c -> Poly c
removeZeros (P cs) = P (V.take (lastNonZero (V.length cs P.- 1) P.+ 1) cs)
  where
    -- Index of the highest non-zero coefficient at or below i, or -1 if
    -- there is none. Indices are always in [0, length cs), so unsafeIndex
    -- is safe.
    lastNonZero :: Int -> Int
    lastNonZero i
      | i < 0                      = -1
      | V.unsafeIndex cs i == zero = lastNonZero (i P.- 1)
      | otherwise                  = i

-- | Right-pad a coefficient vector with 'zero' up to @size@ entries.
-- No padding is requested when @cs@ already has at least @size@ entries
-- (the replicate count is then non-positive).
addZeros :: forall c size . (Ring c, KnownNat size) => V.Vector c -> V.Vector c
addZeros cs = cs V.++ V.replicate missing zero
  where
    missing :: Int
    missing = fromIntegral (value @size) P.- V.length cs


-- ** THE CODE BELOW IS ONLY USED FOR BENCHMARKING MULTIPLICATION **

-- | Naive vector multiplication, O(n^2)
--
-- Multiplies the raw coefficient vectors with 'mulVector'. Trailing zeros of
-- the product are not removed here.
mulPoly :: forall a. Field a => Poly a -> Poly a -> Poly a
mulPoly p q = case (p, q) of
    (P xs, P ys) -> P (mulVector xs ys)

-- | Adaptation of Karatsuba's algorithm. O(n^log_2(3))
--
-- Both coefficient vectors are zero-padded to @pad@, the smallest power of
-- two that is >= the longer input, and then passed to 'mulKaratsuba'.
-- Trailing zeros of the product are removed.
--
-- @pad@ is computed with integer arithmetic. The previous
-- @ceiling (logBase 2 l)@ crashed on two empty inputs (log of 0 is
-- -Infinity, giving a negative exponent) and depended on floating-point
-- rounding. Empty inputs now take the same length-1 path as single
-- coefficients.
mulPolyKaratsuba :: (Eq a, Field a) => Poly a -> Poly a -> Poly a
mulPolyKaratsuba (P v1) (P v2) = removeZeros $ P result
  where
    l :: Int
    l = max (V.length v1) (V.length v2)

    -- smallest power of two >= l (1 when l <= 1)
    pad :: Int
    pad = until (>= l) (P.* 2) 1

    result = mulKaratsuba
        (v1 V.++ V.replicate (pad P.- V.length v1) zero)
        (v2 V.++ V.replicate (pad P.- V.length v2) zero)

-- | DFT multiplication of vectors. O(nlogn)
--
-- Both inputs are zero-padded to @pad = 2^p@ coefficients, where
-- @p = 1 + ceil(log2 l)@ and @l@ is the longer input length. This gives at
-- least @2l@ slots, room for every coefficient of the product. The padded
-- vectors are multiplied with 'mulDft' using @rootOfUnity p@, and trailing
-- zeros of the product are removed.
--
-- @p@ is computed with integer arithmetic. The previous
-- @ceiling (logBase 2 l)@ crashed on two empty inputs and depended on
-- floating-point rounding. A missing root of unity now reports an error
-- message instead of hitting a bare 'undefined'.
mulPolyDft :: forall a . (Eq a, Field a) => Poly a -> Poly a -> Poly a
mulPolyDft (P v1) (P v2) = removeZeros $ P result
  where
    l :: Int
    l = max (V.length v1) (V.length v2)

    -- 1 + (smallest e with 2^e >= l); l <= 1 gives p = 1
    p :: Integer
    p = 1 P.+ ceilLog2 0 1
      where
        ceilLog2 :: Integer -> Int -> Integer
        ceilLog2 e pw
          | pw >= l   = e
          | otherwise = ceilLog2 (e P.+ 1) (pw P.* 2)

    w2n :: a
    w2n = case rootOfUnity (fromIntegral p) of
        Just w  -> w
        Nothing -> error "mulPolyDft: no root of unity available for the required DFT size"

    pad :: Int
    pad = 2 P.^ p

    result = mulDft @a p w2n
        (v1 V.++ V.replicate (pad P.- V.length v1) zero)
        (v2 V.++ V.replicate (pad P.- V.length v2) zero)

-- | Schoolbook multiplication on coefficient lists.
--
-- Each coefficient @x@ of the first polynomial contributes @map (x *) ys@.
-- That contribution is added (via 'zipWithDefault') to the product of the
-- remaining coefficients, shifted up by one degree. Trailing zeros are not
-- removed.
mulPolyNaive :: Field a => Poly a -> Poly a -> Poly a
mulPolyNaive (P v1) (P v2) = P . V.fromList $ foldr step [] (V.toList v1)
    where
        ys = V.toList v2
        step x acc = zipWithDefault (+) zero zero (map (x *) ys) (zero : acc)