{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Poly.Internal.Dense.Field
( fieldGcd
) where
import Prelude hiding (quotRem, quot, rem, gcd, recip)
import Control.Exception
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Euclidean (Euclidean(..), Field)
import Data.Field (recip)
import Data.Semiring (times, minus, zero, one)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
import Data.Poly.Internal.Dense
import Data.Poly.Internal.Dense.GcdDomain ()
-- | Euclidean division for dense polynomials with coefficients in a field.
--
-- * 'degree' is the number of coefficients stored in the underlying vector.
-- * 'quotRem' and 'rem' delegate to 'quotientAndRemainder' and 'remainder'
--   on the raw coefficient vectors and pass the results through 'toPoly''
--   (defined elsewhere; presumably it re-normalises the vector — confirm).
--
-- Dividing by a polynomial whose coefficient vector is empty throws
-- 'DivideByZero' once the result is demanded.
instance (Eq a, Eq (v a), Field a, G.Vector v a) => Euclidean (Poly v a) where
  degree (Poly coeffs) = fromIntegral (G.length coeffs)

  quotRem (Poly dividend) (Poly divisor) =
    -- The lazy @let@ pattern keeps the pair constructor available without
    -- forcing the division itself.
    let (qs, rs) = quotientAndRemainder dividend divisor
     in (toPoly' qs, toPoly' rs)
  {-# INLINE quotRem #-}

  rem (Poly dividend) (Poly divisor) = toPoly' (remainder dividend divisor)
  {-# INLINE rem #-}
-- | Long division on raw coefficient vectors, returning
-- @(quotient, remainder)@.
--
-- The last element of the divisor is used as its leading coefficient, and
-- the dividend is reduced from its highest index downwards.
--
-- * Dividend shorter than divisor: the quotient is empty and the dividend is
--   returned unchanged as the remainder.
-- * Empty divisor: throws 'DivideByZero'.
-- * One-element divisor: every dividend coefficient is multiplied by the
--   inverse of that element; the remainder is empty.
-- * Otherwise: the quotient has @length xs - length ys + 1@ entries and the
--   remainder is the first @length ys@ entries of the reduced dividend
--   (the entry at index @length ys - 1@ has been set to 'zero').
--
-- Neither result is trimmed of trailing zeros here.
quotientAndRemainder
  :: (Eq a, Field a, G.Vector v a)
  => v a
  -> v a
  -> (v a, v a)
quotientAndRemainder xs ys
  | nX < nY = (G.empty, xs)
  | nY == 0 = throw DivideByZero
  | nY == 1 =
      let scale = recip (G.unsafeHead ys)
       in (G.map (`times` scale) xs, G.empty)
  | otherwise = runST $ do
      quotient <- MG.unsafeNew nQ
      -- Working copy of the dividend; it is reduced in place.
      work <- MG.unsafeNew nX
      G.unsafeCopy work xs
      let lead = G.unsafeLast ys
          leadInv = recip lead
          top = nY - 1
      forM_ [nQ - 1, nQ - 2 .. 0] $ \i -> do
        hi <- MG.unsafeRead work (top + i)
        -- Skip the multiplication when the divisor is monic.
        let coeff
              | lead == one = hi
              | otherwise = hi `times` leadInv
        MG.unsafeWrite quotient i coeff
        MG.unsafeWrite work (top + i) zero
        -- Subtract coeff * ys (shifted by i) from the lower coefficients.
        forM_ [0 .. top - 1] $ \k -> do
          let y = G.unsafeIndex ys k
          when (y /= zero) $
            MG.unsafeModify work (`minus` (coeff `times` y)) (i + k)
      qs <- G.unsafeFreeze quotient
      rs <- G.unsafeFreeze (MG.unsafeSlice 0 nY work)
      pure (qs, rs)
  where
    nX = G.length xs
    nY = G.length ys
    nQ = nX - nY + 1
{-# INLINABLE quotientAndRemainder #-}
-- | Remainder of dividing the coefficient vector @xs@ by @ys@.
--
-- Throws 'DivideByZero' when @ys@ is empty. Otherwise a mutable copy of
-- @xs@ is reduced in place by 'remainderM' and its first
-- @min (length xs) (length ys)@ entries are returned. The result is not
-- trimmed of trailing zeros here.
remainder
  :: (Eq a, Field a, G.Vector v a)
  => v a
  -> v a
  -> v a
remainder xs ys
  | G.null ys = throw DivideByZero
  | otherwise = runST $ do
      work <- G.thaw xs
      -- 'remainderM' only reads its second argument, so the divisor can be
      -- thawed without copying.
      divisor <- G.unsafeThaw ys
      remainderM work divisor
      G.unsafeFreeze (MG.unsafeSlice 0 keep work)
  where
    keep = min (G.length xs) (G.length ys)
{-# INLINABLE remainder #-}
-- | In-place remainder: overwrites the mutable vector @xs@ with the
-- remainder of dividing it by @ys@. Only @xs@ is written to; @ys@ is read.
--
-- The last element of @ys@ is used as its leading coefficient.
--
-- * @xs@ shorter than @ys@: nothing is changed.
-- * Empty @ys@: throws 'DivideByZero'.
-- * One-element @ys@: every entry of @xs@ is set to 'zero'.
-- * Otherwise: entries of @xs@ from index @length ys - 1@ upwards end up as
--   'zero', and the lower entries hold the remainder coefficients.
remainderM
  :: (PrimMonad m, Eq a, Field a, G.Vector v a)
  => G.Mutable v (PrimState m) a
  -> G.Mutable v (PrimState m) a
  -> m ()
remainderM xs ys
  | nX < nY = pure ()
  | nY == 0 = throw DivideByZero
  | nY == 1 = MG.set xs zero
  | otherwise = do
      lead <- MG.unsafeRead ys top
      let leadInv = recip lead
      forM_ [nQ - 1, nQ - 2 .. 0] $ \i -> do
        hi <- MG.unsafeRead xs (top + i)
        MG.unsafeWrite xs (top + i) zero
        -- Skip the multiplication when the divisor is monic.
        let coeff
              | lead == one = hi
              | otherwise = hi `times` leadInv
        -- Subtract coeff * ys (shifted by i) from the lower coefficients.
        forM_ [0 .. top - 1] $ \k -> do
          y <- MG.unsafeRead ys k
          when (y /= zero) $
            MG.unsafeModify xs (`minus` (coeff `times` y)) (i + k)
  where
    nX = MG.length xs
    nY = MG.length ys
    nQ = nX - nY + 1
    top = nY - 1
{-# INLINABLE remainderM #-}
-- | Greatest common divisor of two polynomials over a field, computed by
-- the Euclidean algorithm in 'gcdM'.
--
-- Both coefficient vectors are copied into fresh mutable buffers, so the
-- inputs are never modified. The resulting vector is passed through
-- 'toPoly''; no other normalisation is applied in this function.
fieldGcd
  :: (Eq a, Field a, G.Vector v a)
  => Poly v a
  -> Poly v a
  -> Poly v a
fieldGcd (Poly xs) (Poly ys) =
  -- The two thaws are independent, so they are combined applicatively and
  -- the resulting action is then run with 'join'.
  toPoly' $ runST $ join (gcdM <$> G.thaw xs <*> G.thaw ys)
{-# INLINE fieldGcd #-}
-- | Euclidean algorithm on mutable coefficient vectors.
--
-- Each step first strips trailing 'zero' entries from @ys@ using
-- 'dropWhileEndM' (defined elsewhere; assumed to return a trimmed view —
-- confirm). If nothing is left, @xs@ is frozen and returned as the result.
-- Otherwise @xs@ is overwritten with its remainder modulo the trimmed @ys@
-- and the two buffers swap roles.
--
-- Both arguments are consumed: they are mutated in place and one of them is
-- frozen without copying, so callers must not reuse them afterwards.
gcdM
  :: (PrimMonad m, Eq a, Field a, G.Vector v a)
  => G.Mutable v (PrimState m) a
  -> G.Mutable v (PrimState m) a
  -> m (v a)
gcdM xs ys = do
  ys' <- dropWhileEndM (== zero) ys
  if MG.null ys'
    then G.unsafeFreeze xs
    else do
      remainderM xs ys'
      gcdM ys' xs
-- INLINABLE rather than INLINE: this function is self-recursive, and GHC
-- always ignores an INLINE pragma on a self-recursive binding. INLINABLE is
-- the pragma GHC documents for recursive functions, and it matches the other
-- helpers in this module.
{-# INLINABLE gcdM #-}