{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Poly.Internal.Dense
( Poly(..)
, VPoly
, UPoly
, leading
, dropWhileEndM
, toPoly
, monomial
, scale
, pattern X
, eval
, deriv
, integral
, toPoly'
, monomial'
, scale'
, pattern X'
, eval'
, deriv'
) where
import Prelude hiding (quotRem, rem, gcd, lcm, (^))
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.List (foldl', intersperse)
import Data.Semiring (Semiring(..))
import qualified Data.Semiring as Semiring
import qualified Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
import qualified Data.Vector.Unboxed as U
import GHC.Exts
#if !MIN_VERSION_semirings(0,4,0)
import Data.Semigroup
import Numeric.Natural
#endif
-- | Dense polynomial whose coefficients (of type @a@) are stored in a
-- vector of type @v@, constant term first: the element at index @i@ is the
-- coefficient of @X^i@.
--
-- The raw constructor performs no normalisation; 'toPoly' and 'toPoly''
-- strip trailing zero coefficients. 'Eq' and 'Ord' are derived structurally
-- from the underlying coefficient vector.
newtype Poly v a = Poly { unPoly :: v a }
  deriving (Eq, Ord)
-- | List syntax for polynomials: the list holds coefficients from the
-- constant term upwards. Building a polynomial goes through 'toPoly'', so
-- trailing coefficients equal to 'zero' are dropped. 'toList' returns the
-- stored coefficient vector as a list.
instance (Eq a, Semiring a, G.Vector v a) => IsList (Poly v a) where
  type Item (Poly v a) = a
  fromList coeffs = toPoly' (G.fromList coeffs)
  fromListN len coeffs = toPoly' (G.fromListN len coeffs)
  toList (Poly coeffs) = G.toList coeffs
-- | Renders a polynomial as a sum of terms @c * X^i@, highest power first.
--
-- * An empty coefficient vector is shown as @0@.
-- * A single coefficient is shown on its own, at the caller's precedence.
-- * Otherwise every stored coefficient, including zeros, becomes a term;
--   terms are joined by @" + "@ and the whole sum is parenthesised
--   whenever the surrounding precedence is above 0.
instance (Show a, G.Vector v a) => Show (Poly v a) where
  showsPrec prec (Poly coeffs) =
    case G.length coeffs of
      0 -> showString "0"
      1 -> showsPrec prec (G.head coeffs)
      _ -> showParen (prec > 0)
             (foldr (.) id (intersperse (showString " + ") terms))
    where
      -- 'G.ifoldl' conses each new term onto the front of the accumulator,
      -- so the highest power ends up first in the list.
      terms = G.ifoldl (\acc i c -> term i c : acc) [] coeffs

      -- Render one coefficient together with its power of X.
      term :: Int -> a -> ShowS
      term 0 c = showsPrec 7 c
      term 1 c = showsPrec 7 c . showString " * X"
      term i c = showsPrec 7 c . showString " * X^" . showsPrec 7 i
-- | Dense polynomials whose coefficients live in a boxed vector ('V.Vector').
type VPoly = Poly V.Vector
-- | Dense polynomials whose coefficients live in an unboxed vector ('U.Vector').
type UPoly = Poly U.Vector
-- | Wrap a coefficient vector (constant term first) as a 'Poly', dropping
-- every trailing coefficient equal to @0@ so the result is normalised.
toPoly :: (Eq a, Num a, G.Vector v a) => v a -> Poly v a
toPoly coeffs = Poly (dropWhileEnd (\c -> c == 0) coeffs)
-- | Semiring counterpart of 'toPoly': wrap a coefficient vector (constant
-- term first) as a 'Poly', dropping trailing coefficients equal to 'zero'.
toPoly' :: (Eq a, Semiring a, G.Vector v a) => v a -> Poly v a
toPoly' coeffs = Poly (dropWhileEnd (\c -> c == zero) coeffs)
-- | Degree and coefficient of the highest stored term, or 'Nothing' when
-- the coefficient vector is empty.
--
-- No normalisation happens here: the last stored coefficient is returned
-- as-is. It is only guaranteed nonzero for values built via 'toPoly' or
-- 'toPoly''.
leading :: G.Vector v a => Poly v a -> Maybe (Word, a)
leading (Poly coeffs) =
  if G.null coeffs
    then Nothing
    else Just (fromIntegral (G.length coeffs - 1), G.last coeffs)
-- | Ring operations on polynomials with 'Num' coefficients.
--
-- Sums, differences and products are re-normalised with 'toPoly'.
-- Products go through 'karatsuba'. 'negate' maps over the coefficients
-- without re-normalising. 'abs' is the identity and 'signum' is the
-- constant polynomial @1@. 'fromInteger' yields the empty polynomial when
-- the converted coefficient equals @0@.
instance (Eq a, Num a, G.Vector v a) => Num (Poly v a) where
  Poly xs + Poly ys = toPoly (plusPoly (+) xs ys)
  Poly xs - Poly ys = toPoly (minusPoly negate (-) xs ys)
  Poly xs * Poly ys = toPoly (karatsuba xs ys)
  negate = Poly . G.map negate . unPoly
  abs = id
  signum = const 1
  fromInteger n
    | c == 0 = Poly G.empty
    | otherwise = Poly (G.singleton c)
    where
      c = fromInteger n
  {-# INLINE (+) #-}
  {-# INLINE (-) #-}
  {-# INLINE (*) #-}
  {-# INLINE negate #-}
  {-# INLINE fromInteger #-}
-- | Semiring structure on polynomials, using only 'zero', 'one', 'plus' and
-- 'times' of the coefficients. Results of 'plus' and 'times' are
-- re-normalised with 'toPoly''. Products use schoolbook 'convolution':
-- 'karatsuba' needs subtraction, which a plain semiring lacks.
instance (Eq a, Semiring a, G.Vector v a) => Semiring (Poly v a) where
  zero = Poly G.empty
  -- When the coefficient semiring is trivial (one == zero), so is this one.
  one = if (one :: a) == zero then zero else Poly (G.singleton one)
  plus (Poly xs) (Poly ys) = toPoly' (plusPoly plus xs ys)
  times (Poly xs) (Poly ys) = toPoly' (convolution zero plus times xs ys)
  {-# INLINE zero #-}
  {-# INLINE one #-}
  {-# INLINE plus #-}
  {-# INLINE times #-}
#if MIN_VERSION_semirings(0,4,0)
  fromNatural n
    | c == zero = zero
    | otherwise = Poly (G.singleton c)
    where
      c :: a
      c = fromNatural n
  {-# INLINE fromNatural #-}
#endif
-- | Additive inverses, taken coefficient-wise with 'Semiring.negate'.
-- The result is not re-normalised.
instance (Eq a, Semiring.Ring a, G.Vector v a) => Semiring.Ring (Poly v a) where
  negate = Poly . G.map Semiring.negate . unPoly
-- | Vector analogue of 'Data.List.dropWhileEnd': remove the longest suffix
-- whose elements all satisfy the predicate. The result is taken with
-- 'G.basicUnsafeSlice' from the start of the input.
dropWhileEnd
  :: G.Vector v a
  => (a -> Bool)
  -> v a
  -> v a
dropWhileEnd p xs = G.basicUnsafeSlice 0 (keepLen (G.basicLength xs)) xs
  where
    -- Step backwards from the end while the last kept element matches.
    keepLen n
      | n > 0, p (G.unsafeIndex xs (n - 1)) = keepLen (n - 1)
      | otherwise = n
{-# INLINE dropWhileEnd #-}
-- | Mutable-vector counterpart of 'dropWhileEnd'. It reads elements from
-- the end of @xs@ until one fails the predicate, then returns the prefix
-- slice (via 'MG.basicUnsafeSlice') that ends just after it.
dropWhileEndM
  :: (PrimMonad m, G.Vector v a)
  => (a -> Bool)
  -> G.Mutable v (PrimState m) a
  -> m (G.Mutable v (PrimState m) a)
dropWhileEndM p xs =
  (\len -> MG.basicUnsafeSlice 0 len xs) <$> keepLen (MG.basicLength xs)
  where
    -- Number of leading elements to keep, found by scanning backwards.
    keepLen 0 = pure 0
    keepLen n = do
      lastElem <- MG.unsafeRead xs (n - 1)
      if p lastElem then keepLen (n - 1) else pure n
{-# INLINE dropWhileEndM #-}
-- | Coefficient-wise sum of two vectors of possibly different lengths.
-- Positions present in both inputs are combined with @add@. The remaining
-- positions are copied from the longer input. The result has the longer
-- length and is not trimmed.
plusPoly
  :: G.Vector v a
  => (a -> a -> a)
  -> v a
  -> v a
  -> v a
plusPoly add xs ys = runST $ do
  out <- MG.basicUnsafeNew longLen
  forM_ [0 .. shortLen - 1] $ \i ->
    MG.unsafeWrite out i $ add (G.unsafeIndex xs i) (G.unsafeIndex ys i)
  -- Past the shared prefix only the longer operand has coefficients.
  G.unsafeCopy
    (MG.basicUnsafeSlice shortLen tailLen out)
    (G.basicUnsafeSlice shortLen tailLen longer)
  G.unsafeFreeze out
  where
    lenXs = G.basicLength xs
    lenYs = G.basicLength ys
    (shortLen, longLen, longer) =
      if lenXs <= lenYs then (lenXs, lenYs, ys) else (lenYs, lenXs, xs)
    tailLen = longLen - shortLen
{-# INLINE plusPoly #-}
-- | Coefficient-wise difference @xs - ys@ of vectors with possibly
-- different lengths. Shared positions use @sub@. Positions only in @ys@
-- are negated with @neg@, and positions only in @xs@ are copied unchanged.
-- The result has the longer length and is not trimmed.
minusPoly
  :: G.Vector v a
  => (a -> a)
  -> (a -> a -> a)
  -> v a
  -> v a
  -> v a
minusPoly neg sub xs ys = runST $ do
  out <- MG.basicUnsafeNew (lenXs `max` lenYs)
  forM_ [0 .. (lenXs `min` lenYs) - 1] $ \i ->
    MG.unsafeWrite out i $ sub (G.unsafeIndex xs i) (G.unsafeIndex ys i)
  if lenYs > lenXs
    then
      -- Only the subtrahend reaches these degrees: negate its coefficients.
      forM_ [lenXs .. lenYs - 1] $ \i ->
        MG.unsafeWrite out i $ neg (G.unsafeIndex ys i)
    else
      -- Whatever remains belongs to the minuend: copy it verbatim.
      G.unsafeCopy
        (MG.basicUnsafeSlice lenYs surplus out)
        (G.basicUnsafeSlice lenYs surplus xs)
  G.unsafeFreeze out
  where
    lenXs = G.basicLength xs
    lenYs = G.basicLength ys
    surplus = lenXs - lenYs
{-# INLINE minusPoly #-}
-- | Operand length at or below which 'karatsuba' stops recursing. If
-- either operand has at most this many coefficients, the product is
-- computed directly with 'convolution'.
-- NOTE(review): 32 looks like an empirically tuned cutoff; confirm with
-- benchmarks before changing it.
karatsubaThreshold :: Int
karatsubaThreshold = 32
-- | Product of two coefficient vectors using Karatsuba's method.
--
-- If either operand has at most 'karatsubaThreshold' coefficients
-- (including the empty case), this falls back to 'convolution'. Otherwise
-- both operands are split at @m@, so that @x = x0 + x1*X^m@ and
-- @y = y0 + y1*X^m@. The product is then assembled from three recursive
-- products:
--
-- > x*y = z0 + (z11 - z0 - z2) * X^m + z2 * X^(2m)
--
-- where @z0 = x0*y0@, @z2 = x1*y1@ and @z11 = (x0+x1)*(y0+y1)@. The result
-- has length @lenXs + lenYs - 1@ and is not trimmed.
karatsuba
  :: (Eq a, Num a, G.Vector v a)
  => v a
  -> v a
  -> v a
karatsuba xs ys
  | lenXs <= karatsubaThreshold || lenYs <= karatsubaThreshold
  = convolution 0 (+) (*) xs ys
  | otherwise
  = G.generate lenZs combine
  where
    lenXs = G.basicLength xs
    lenYs = G.basicLength ys
    lenZs = lenXs + lenYs - 1
    -- Split point: half of the shorter operand, rounded up.
    m = ((lenXs `min` lenYs) + 1) `quot` 2
    xs0 = G.slice 0 m xs
    xs1 = G.slice m (lenXs - m) xs
    ys0 = G.slice 0 m ys
    ys1 = G.slice m (lenYs - m) ys
    zs0 = karatsuba xs0 ys0
    zs2 = karatsuba xs1 ys1
    zs11 = karatsuba (plusPoly (+) xs0 xs1) (plusPoly (+) ys0 ys1)
    -- Read a coefficient, treating out-of-range indices as 0.
    coeffAt vs i =
      if i >= 0 && i < G.basicLength vs then G.unsafeIndex vs i else 0
    -- Coefficient k of the product. The association of the additions is
    -- kept exactly as in the textbook formula above.
    combine k =
      coeffAt zs0 k
        + (coeffAt zs11 (k - m) - coeffAt zs0 (k - m) - coeffAt zs2 (k - m))
        + coeffAt zs2 (k - 2 * m)
{-# INLINE karatsuba #-}
-- | Schoolbook product of two coefficient vectors, parameterised by the
-- coefficient operations: @zer@ is the additive identity, @add@ the sum
-- and @mul@ the product.
--
-- If either input is empty the result is empty. Otherwise the result has
-- length @lenXs + lenYs - 1@, and coefficient @k@ is the left fold with
-- @add@ of @mul x_i y_(k-i)@ over all valid @i@. No trimming is done.
convolution
  :: G.Vector v a
  => a
  -> (a -> a -> a)
  -> (a -> a -> a)
  -> v a
  -> v a
  -> v a
convolution zer add mul xs ys
  | G.null xs || G.null ys = G.empty
  | otherwise = G.generate (lenXs + lenYs - 1) coeff
  where
    lenXs = G.basicLength xs
    lenYs = G.basicLength ys
    -- Indices i are restricted so that both i and k - i are in range.
    coeff k = foldl'
      (\acc i -> add acc (mul (G.unsafeIndex xs i) (G.unsafeIndex ys (k - i))))
      zer
      [max 0 (k - lenYs + 1) .. min k (lenXs - 1)]
{-# INLINE convolution #-}
-- | The monomial @c * X^p@. A zero coefficient gives the empty (zero)
-- polynomial. Otherwise the vector holds @p@ zeros followed by @c@.
monomial :: (Eq a, Num a, G.Vector v a) => Word -> a -> Poly v a
monomial p c
  | c == 0 = Poly G.empty
  | otherwise = Poly (G.generate (deg + 1) (\i -> if i == deg then c else 0))
  where
    deg = fromIntegral p
{-# INLINE monomial #-}
-- | Semiring counterpart of 'monomial': @c * X^p@. A coefficient equal to
-- 'zero' gives the empty polynomial. Otherwise the vector holds @p@ copies
-- of 'zero' followed by @c@.
monomial' :: (Eq a, Semiring a, G.Vector v a) => Word -> a -> Poly v a
monomial' p c =
  if c == zero
    then Poly G.empty
    else Poly (G.generate (deg + 1) pick)
  where
    deg = fromIntegral p
    pick i = if i == deg then c else zero
{-# INLINE monomial' #-}
-- | Multiply a coefficient vector by the monomial @yc * X^yp@.
--
-- The result starts with @yp@ copies of @zer@, followed by @mul yc x@ for
-- each coefficient @x@ of the input, in order. Trailing zeros are not
-- trimmed; callers normalise afterwards.
--
-- No equality on coefficients is needed here, so the helper carries no
-- @Eq a@ constraint. The original one was redundant.
scaleInternal
  :: G.Vector v a
  => a              -- ^ zero of the coefficient type, used for the padding
  -> (a -> a -> a)  -- ^ coefficient multiplication
  -> Word           -- ^ power @yp@ of X to shift by
  -> a              -- ^ coefficient @yc@ to multiply by
  -> v a            -- ^ input coefficients, constant term first
  -> v a
scaleInternal zer mul yp yc xs = runST $ do
  zs <- MG.basicUnsafeNew (padLen + lenXs)
  forM_ [0 .. padLen - 1] $ \k ->
    MG.unsafeWrite zs k zer
  forM_ [0 .. lenXs - 1] $ \k ->
    MG.unsafeWrite zs (padLen + k) (mul yc (G.unsafeIndex xs k))
  G.unsafeFreeze zs
  where
    lenXs = G.basicLength xs
    padLen = fromIntegral yp
{-# INLINE scaleInternal #-}
-- | Multiply a polynomial by the monomial @yc * X^yp@, re-normalising the
-- result with 'toPoly'.
scale :: (Eq a, Num a, G.Vector v a) => Word -> a -> Poly v a -> Poly v a
scale yp yc = toPoly . scaleInternal 0 (*) yp yc . unPoly
-- | Semiring counterpart of 'scale': multiply by @yc * X^yp@ and
-- re-normalise with 'toPoly''.
scale' :: (Eq a, Semiring a, G.Vector v a) => Word -> a -> Poly v a -> Poly v a
scale' yp yc = toPoly' . scaleInternal zero times yp yc . unPoly
-- | A pair with strict fields, used as the accumulator of the folds in
-- 'eval' and 'eval''. Both components are forced to weak head normal form
-- whenever a pair is built.
data StrictPair a b = !a :*: !b

infixr 1 :*:

-- | First component of a 'StrictPair'.
fst' :: StrictPair a b -> a
fst' pair = case pair of
  a :*: _ -> a
-- | Evaluate a polynomial at @x@. The sum @c_i * x^i@ is accumulated from
-- the constant term upwards, carrying the current power of @x@ in a strict
-- pair.
eval :: (Num a, G.Vector v a) => Poly v a -> a -> a
eval (Poly cs) x = fst' (G.foldl' step (0 :*: 1) cs)
  where
    -- acc: partial sum so far; xn: x raised to the current index.
    step (acc :*: xn) cn = (acc + cn * xn) :*: (x * xn)
{-# INLINE eval #-}
-- | Semiring counterpart of 'eval': accumulate @c_i `times` x^i@ with
-- 'plus', starting from @zero@ with running power @one@.
eval' :: (Semiring a, G.Vector v a) => Poly v a -> a -> a
eval' (Poly cs) x = fst' (G.foldl' step (zero :*: one) cs)
  where
    -- Parenthesised explicitly; this matches the semirings fixities
    -- (times binds tighter than plus).
    step (acc :*: xn) cn = (acc `plus` (cn `times` xn)) :*: (x `times` xn)
{-# INLINE eval' #-}
-- | Formal derivative. The constant term is dropped and the coefficient of
-- @X^(i+1)@ is multiplied by @i+1@. The result is re-normalised with
-- 'toPoly'.
deriv :: (Eq a, Num a, G.Vector v a) => Poly v a -> Poly v a
deriv (Poly xs)
  | G.null xs = Poly G.empty
  | otherwise = toPoly (G.imap timesPower (G.tail xs))
  where
    -- Element i of the tail is the coefficient of X^(i+1).
    timesPower i x = fromIntegral (i + 1) * x
{-# INLINE deriv #-}
-- | Semiring counterpart of 'deriv'. The multiplier @i+1@ is embedded into
-- the coefficient type with 'fromNatural', and the result is re-normalised
-- with 'toPoly''.
deriv' :: (Eq a, Semiring a, G.Vector v a) => Poly v a -> Poly v a
deriv' (Poly xs)
  | G.null xs = Poly G.empty
  | otherwise = toPoly' (G.imap timesPower (G.tail xs))
  where
    -- Element i of the tail is the coefficient of X^(i+1).
    timesPower i x = fromNatural (fromIntegral (i + 1)) `times` x
{-# INLINE deriv' #-}
#if !MIN_VERSION_semirings(0,4,0)
-- | Fallback used only when building against semirings < 0.4 (see the CPP
-- guard). Embeds a natural number @n@ as 'one' combined with itself @n@
-- times via 'plus', and as 'zero' when @n == 0@.
fromNatural :: Semiring a => Natural -> a
fromNatural n
  | n == 0 = zero
  | otherwise = getAdd' (stimes n (Add' one))

-- | Wrapper whose 'Semigroup' operation is the semiring's 'plus'.
newtype Add' a = Add' { getAdd' :: a }

instance Semiring a => Semigroup (Add' a) where
  Add' l <> Add' r = Add' (l `plus` r)
#endif
-- | Formal antiderivative with zero constant term. The coefficient of
-- @X^i@ becomes the coefficient of @X^(i+1)@, multiplied by
-- @recip (i + 1)@. The result is re-normalised with 'toPoly', and the
-- empty polynomial maps to itself.
integral :: (Eq a, Fractional a, G.Vector v a) => Poly v a -> Poly v a
integral (Poly xs)
  | G.null xs = Poly G.empty
  | otherwise = toPoly (G.cons 0 (G.imap divideByPower xs))
  where
    divideByPower i c = c * recip (fromIntegral i + 1)
{-# INLINE integral #-}
-- | The polynomial @X@, with coefficients @[0, 1]@. As a pattern it
-- matches exactly the values equal to 'var'.
pattern X :: (Eq a, Num a, G.Vector v a, Eq (v a)) => Poly v a
pattern X <- ((==) var -> True)
  where
    X = var

-- | Representation of @X@. Over a coefficient type where @1 == 0@ it
-- collapses to the empty (zero) polynomial.
var :: forall a v. (Eq a, Num a, G.Vector v a, Eq (v a)) => Poly v a
var =
  if (1 :: a) == 0
    then Poly G.empty
    else Poly (G.fromList [0, 1])
{-# INLINE var #-}
-- | Semiring counterpart of 'X': the polynomial with coefficients
-- @[zero, one]@. As a pattern it matches exactly the values equal to
-- 'var''.
pattern X' :: (Eq a, Semiring a, G.Vector v a, Eq (v a)) => Poly v a
pattern X' <- ((==) var' -> True)
  where
    X' = var'

-- | Representation of @X'@. Over a semiring where @one == zero@ it
-- collapses to the empty (zero) polynomial.
var' :: forall a v. (Eq a, Semiring a, G.Vector v a, Eq (v a)) => Poly v a
var' =
  if (one :: a) == zero
    then Poly G.empty
    else Poly (G.fromList [zero, one])
{-# INLINE var' #-}