{-# LANGUAGE BangPatterns, ViewPatterns #-}
module Math.BernsteinPoly
(BernsteinPoly(..), bernsteinSubsegment, listToBernstein, zeroPoly,
(~*), (*~), (~+), (~-), degreeElevate, bernsteinSplit, bernsteinEval,
bernsteinEvalDeriv, binCoeff, convolve, bernsteinEvalDerivs, bernsteinDeriv)
where
import Data.Vector.Unboxed as V
import Data.Vector.Unboxed.Mutable as M
import qualified Data.Vector as B
-- | A polynomial in Bernstein form, stored as an unboxed vector of
-- coefficients.  A vector of length @n + 1@ is treated by the functions in
-- this module as a polynomial of degree @n@.
data BernsteinPoly a = BernsteinPoly
  { bernsteinCoeffs :: V.Vector a  -- ^ the Bernstein coefficients
  }
  deriving Show
-- | Internal representation in the /scaled/ Bernstein basis: each
-- coefficient has already been multiplied by the matching binomial
-- coefficient (see 'toScaled' / 'fromScaled').
data ScaledPoly a = ScaledPoly
  { scaledCoeffs :: V.Vector a  -- ^ the scaled coefficients
  }
  deriving Show
-- Operator fixities: the multiplicative operators (~* and *~) bind tighter
-- (level 7) than the additive ones (~+ and ~-, level 6), and all of them are
-- left associative -- the same levels Prelude gives (*) and (+)/(-).
infixl 7 ~*, *~
infixl 6 ~+, ~-
-- Rewrite rules that cancel a conversion immediately followed by its inverse:
-- @toScaled (fromScaled a)@ and @fromScaled (toScaled a)@ are both replaced
-- by @a@.  When a rule fires, the multiplication and division by the binomial
-- coefficients are skipped altogether.
{-# RULES "toScaled/fromScaled" forall a. toScaled (fromScaled a) = a;
"fromScaled/toScaled" forall a. fromScaled (toScaled a) = a; #-}
-- | Convert to the scaled basis by multiplying the i-th coefficient with the
-- i-th entry of @binCoeff n@, where @n@ is the coefficient count minus one.
toScaled :: (Unbox a, Num a) => BernsteinPoly a -> ScaledPoly a
toScaled (BernsteinPoly coeffs) = ScaledPoly (V.zipWith (*) coeffs weights)
  where
    -- Binomial weights for a polynomial with this many coefficients.
    weights = binCoeff (V.length coeffs - 1)
{-# NOINLINE [2] toScaled #-}
-- | Convert back from the scaled basis by dividing the i-th coefficient by
-- the i-th entry of @binCoeff n@, where @n@ is the coefficient count minus
-- one.
fromScaled :: (Unbox a, Fractional a) => ScaledPoly a -> BernsteinPoly a
fromScaled (ScaledPoly coeffs) = BernsteinPoly (V.zipWith (/) coeffs weights)
  where
    -- Binomial weights for a polynomial with this many coefficients.
    weights = binCoeff (V.length coeffs - 1)
{-# NOINLINE [2] fromScaled #-}
-- | Build a Bernstein polynomial from a list of coefficients.  The empty
-- list is mapped to 'zeroPoly' rather than to an empty coefficient vector.
listToBernstein :: (Unbox a, Num a) => [a] -> BernsteinPoly a
listToBernstein coeffs = case coeffs of
  [] -> zeroPoly
  _  -> BernsteinPoly (V.fromList coeffs)
{-# INLINE listToBernstein #-}
-- | The constant zero polynomial, represented by the single coefficient @0@.
zeroPoly :: (Num a, Unbox a) => BernsteinPoly a
zeroPoly = BernsteinPoly (V.singleton 0)
{-# SPECIALIZE zeroPoly :: BernsteinPoly Double #-}
-- | Return the part of the polynomial between the two parameters,
-- reparameterised as a new Bernstein polynomial.  The parameters may be
-- given in either order.
--
-- The usual path splits at @t2@, keeps the first half, and splits that half
-- again at @t1 / t2@.  When @t2@ is zero that ratio is undefined (NaN or
-- infinity for floating point, an exception for exact types), so in that
-- case the order is reversed: split at @t1@ first, keep the second half, and
-- split it at @(t2 - t1) / (1 - t1)@.  With @t1 <= t2 == 0@ the denominator
-- is at least one, so the division is always safe.
bernsteinSubsegment :: (Unbox a, Ord a, Fractional a) =>
                       BernsteinPoly a -> a -> a -> BernsteinPoly a
bernsteinSubsegment b t1 t2
  | t1 > t2   = bernsteinSubsegment b t2 t1
  | t2 == 0   = fst $ flip bernsteinSplit (negate t1 / (1 - t1)) $
                snd $ bernsteinSplit b t1
  | otherwise = snd $ flip bernsteinSplit (t1 / t2) $
                fst $ bernsteinSplit b t2
{-# INLINE bernsteinSubsegment #-}
-- | Discrete convolution of two coefficient vectors.  The result has
-- @length x + length h - 1@ entries (none if that is not positive), and
-- entry @k@ is the sum of @x[i] * h[j]@ over all @i + j == k@.
convolve :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a
convolve x h = V.create $ do
  let lenX = V.length x
      lenH = V.length h
  -- Accumulator, initialised to zero and updated in place.
  acc <- M.replicate (lenX + lenH - 1) 0
  V.forM_ (V.enumFromN 0 lenX) $ \i -> do
    let xi = V.unsafeIndex x i
    V.forM_ (V.enumFromN 0 lenH) $ \j ->
      -- i < lenX and j < lenH, so i + j is always within the accumulator.
      M.unsafeModify acc (+ xi * V.unsafeIndex h j) (i + j)
  return acc
{-# SPECIALIZE convolve :: Vector Double -> Vector Double -> Vector Double #-}
-- | Multiply two Bernstein polynomials.  Both operands are converted to the
-- scaled basis, multiplied there, and the product is converted back.
(~*) :: (Unbox a, Fractional a) =>
        BernsteinPoly a -> BernsteinPoly a -> BernsteinPoly a
p ~* q = fromScaled (mulScaled (toScaled p) (toScaled q))
{-# INLINE (~*) #-}
-- | Product of two polynomials in the scaled basis: the convolution of their
-- coefficient vectors.
mulScaled :: (Unbox a, Num a) => ScaledPoly a -> ScaledPoly a -> ScaledPoly a
mulScaled (ScaledPoly lhs) (ScaledPoly rhs) = ScaledPoly (convolve lhs rhs)
{-# INLINE mulScaled #-}
-- | The row of binomial coefficients @C(n,0) .. C(n,n)@ as a vector of
-- @n + 1@ entries.  The row is built with 'Int' arithmetic using
-- @C(n,k) = C(n,k-1) * (n-k+1) / k@ and converted to the target type at the
-- end.
binCoeff :: (Num a, Unbox a) => Int -> V.Vector a
binCoeff n = V.map fromIntegral row
  where
    row :: V.Vector Int
    row = V.scanl step 1 (V.enumFromN 1 n)

    -- Multiply before dividing; the product is always divisible by k.
    step :: Int -> Int -> Int
    step acc k = (acc * (n - k + 1)) `quot` k
{-# INLINE binCoeff #-}
-- | Raise the degree of a scaled polynomial by @times@ by convolving its
-- coefficients with @binCoeff times@.  A non-positive @times@ returns the
-- polynomial unchanged.
degreeElevateScaled :: (Unbox a, Num a)
                    => ScaledPoly a -> Int -> ScaledPoly a
degreeElevateScaled poly@(ScaledPoly coeffs) times
  | times > 0 = ScaledPoly (convolve (binCoeff times) coeffs)
  | otherwise = poly
{-# SPECIALIZE degreeElevateScaled :: ScaledPoly Double ->
                                      Int -> ScaledPoly Double #-}
-- | Raise the degree of a Bernstein polynomial by the given amount.  The
-- work is done in the scaled basis by 'degreeElevateScaled'.
degreeElevate :: (Unbox a, Fractional a)
              => BernsteinPoly a -> Int -> BernsteinPoly a
degreeElevate poly times =
  fromScaled (degreeElevateScaled (toScaled poly) times)
{-# INLINE degreeElevate #-}
-- | Evaluate a Bernstein polynomial at the given parameter.
--
-- An empty coefficient vector evaluates to @0@ and a single coefficient is
-- returned as is.  Otherwise a single pass over the coefficients keeps a
-- running power of @t@, a running binomial coefficient, and an accumulator
-- that is multiplied by @1 - t@ at every step.
bernsteinEval :: (Unbox a, Fractional a)
              => BernsteinPoly a -> a -> a
bernsteinEval (BernsteinPoly v) t
  | V.null v        = 0
  | V.length v == 1 = V.unsafeHead v
  | otherwise       = loop t (fromIntegral deg) (V.unsafeIndex v 0 * s) 1
  where
    s   = 1 - t
    deg = V.length v - 1
    -- tPow  : t^k
    -- binom : C(deg, k)
    -- acc   : partial sum, already multiplied by the pending (1 - t) factors
    loop !tPow !binom !acc !k
      | k == deg  = acc + tPow * V.unsafeIndex v deg
      | otherwise =
          loop (tPow * t)
               (binom * fromIntegral (deg - k) / (fromIntegral k + 1))
               ((acc + tPow * binom * V.unsafeIndex v k) * s)
               (k + 1)
{-# SPECIALIZE bernsteinEval :: BernsteinPoly Double -> Double -> Double #-}
-- | Evaluate a Bernstein polynomial and its first derivative at the given
-- parameter, returned as @(value, derivative)@.
--
-- An empty coefficient vector is treated like 'bernsteinEval' treats it, as
-- the zero polynomial, instead of reading the head of an empty vector.  A
-- single coefficient is a constant, so its derivative is @0@.
bernsteinEvalDeriv :: (Unbox t, Fractional t) => BernsteinPoly t -> t -> (t,t)
bernsteinEvalDeriv b@(BernsteinPoly v) t
  | V.null v        = (0, 0)
  | V.length v == 1 = (V.unsafeHead v, 0)
  | otherwise       = (bernsteinEval b t, bernsteinEval (bernsteinDeriv b) t)
{-# INLINE bernsteinEvalDeriv #-}
-- | Evaluate a Bernstein polynomial and all of its derivatives at the given
-- parameter.  The list starts with the value itself, continues with the
-- successive derivatives, and ends with a trailing @0@ for the derivative of
-- the final constant.
--
-- An empty coefficient vector is treated as the zero polynomial (giving
-- @[0, 0]@) instead of reading the head of an empty vector.
bernsteinEvalDerivs :: (Unbox t, Fractional t) => BernsteinPoly t -> t -> [t]
bernsteinEvalDerivs b@(BernsteinPoly v) t
  | V.null v        = [0, 0]
  | V.length v == 1 = [V.unsafeHead v, 0]
  | otherwise       = bernsteinEval b t :
                      bernsteinEvalDerivs (bernsteinDeriv b) t
{-# INLINE bernsteinEvalDerivs #-}
-- | Derivative of a Bernstein polynomial.
--
-- For @n + 1@ coefficients the result has @n@ coefficients, each being @n@
-- times the difference of two neighbouring input coefficients.  A polynomial
-- with at most one coefficient is constant, so its derivative is 'zeroPoly';
-- this keeps the result from ever having an empty coefficient vector.
bernsteinDeriv :: (Unbox a, Num a) => BernsteinPoly a -> BernsteinPoly a
bernsteinDeriv (BernsteinPoly v)
  | V.length v <= 1 = zeroPoly
  | otherwise =
      BernsteinPoly $
      V.map (* fromIntegral (V.length v - 1)) $
      V.zipWith (-) (V.tail v) v
{-# SPECIALIZE bernsteinDeriv :: BernsteinPoly Double ->
                                 BernsteinPoly Double #-}
-- | Split a Bernstein polynomial at the given parameter, returning the part
-- before it and the part after it.
--
-- The coefficient vector is repeatedly reduced by interpolating neighbouring
-- entries at @t@.  The first entries of the successive levels form the first
-- polynomial; the last entries, in reverse order, form the second.
bernsteinSplit :: (Unbox a, Num a) =>
                  BernsteinPoly a -> a -> (BernsteinPoly a, BernsteinPoly a)
bernsteinSplit (BernsteinPoly v) t = (BernsteinPoly lower, BernsteinPoly upper)
  where
    lower = convert (B.map V.head levels)
    upper = V.reverse (convert (B.map V.last levels))
    -- One level per coefficient: the input itself, then each reduction.
    levels = B.iterateN (V.length v) reduce v
    -- Interpolate every pair of neighbours, shortening the vector by one.
    reduce row = V.zipWith lerp row (V.tail row)
    lerp p q = (1 - t) * p + t * q
{-# SPECIALIZE bernsteinSplit :: BernsteinPoly Double -> Double ->
                                 (BernsteinPoly Double, BernsteinPoly Double) #-}
-- | Sum of two polynomials in the scaled basis.  If the coefficient counts
-- differ, the shorter operand is degree elevated to match before the
-- element-wise addition.
addScaled :: (Unbox a, Num a) => ScaledPoly a -> ScaledPoly a -> ScaledPoly a
addScaled pa@(ScaledPoly a) pb@(ScaledPoly b) =
  case compare la lb of
    LT -> ScaledPoly $
          V.zipWith (+) (scaledCoeffs (degreeElevateScaled pa (lb - la))) b
    GT -> ScaledPoly $
          V.zipWith (+) a (scaledCoeffs (degreeElevateScaled pb (la - lb)))
    EQ -> ScaledPoly $ V.zipWith (+) a b
  where
    la = V.length a
    lb = V.length b
{-# SPECIALIZE addScaled :: ScaledPoly Double -> ScaledPoly Double ->
                            ScaledPoly Double #-}
-- | Add two Bernstein polynomials.  Both operands are converted to the
-- scaled basis, added there, and the sum is converted back.
(~+) :: (Unbox a, Fractional a) =>
        BernsteinPoly a -> BernsteinPoly a -> BernsteinPoly a
(toScaled -> a) ~+ (toScaled -> b) = fromScaled $ addScaled a b
{-# INLINE (~+) #-}
-- | Difference of two polynomials in the scaled basis.  If the coefficient
-- counts differ, the shorter operand is degree elevated to match before the
-- element-wise subtraction.
subScaled :: (Unbox a, Num a) => ScaledPoly a -> ScaledPoly a -> ScaledPoly a
subScaled pa@(ScaledPoly a) pb@(ScaledPoly b) =
  case compare la lb of
    LT -> ScaledPoly $
          V.zipWith (-) (scaledCoeffs (degreeElevateScaled pa (lb - la))) b
    GT -> ScaledPoly $
          V.zipWith (-) a (scaledCoeffs (degreeElevateScaled pb (la - lb)))
    EQ -> ScaledPoly $ V.zipWith (-) a b
  where
    la = V.length a
    lb = V.length b
{-# SPECIALIZE subScaled :: ScaledPoly Double -> ScaledPoly Double ->
                            ScaledPoly Double #-}
-- | Subtract the second Bernstein polynomial from the first.  Both operands
-- are converted to the scaled basis, subtracted there, and the difference is
-- converted back.
(~-) :: (Unbox a, Fractional a) =>
        BernsteinPoly a -> BernsteinPoly a -> BernsteinPoly a
p ~- q = fromScaled (subScaled (toScaled p) (toScaled q))
{-# INLINE (~-) #-}
-- | Scale a Bernstein polynomial by a constant: every coefficient is
-- multiplied by the scalar (with the scalar as the right-hand factor).
(*~) :: (Unbox a, Num a) => a -> BernsteinPoly a -> BernsteinPoly a
k *~ BernsteinPoly coeffs = BernsteinPoly $ V.map (\c -> c * k) coeffs
{-# INLINE (*~) #-}