module Math.BernsteinPoly
(BernsteinPoly(..), bernsteinSubsegment, listToBernstein, zeroPoly, (~*), (*~), (~+),
(~-), degreeElevate, bernsteinSplit, bernsteinEval,
bernsteinEvalDerivs, bernsteinDeriv)
where
import Data.List
-- | A polynomial in Bernstein form:
-- @sum_{i=0}^{n} C(n,i) t^i (1-t)^(n-i) c_i@, where @n@ is
-- 'bernsteinDegree' and @c_0 .. c_n@ are the 'bernsteinCoeffs'.
-- The coefficient list is expected to hold @n + 1@ entries, which is how
-- 'listToBernstein' builds it.
data BernsteinPoly = BernsteinPoly
  { bernsteinDegree :: !Int
    -- ^ Degree @n@ of the polynomial.
  , bernsteinCoeffs :: ![Double]
    -- ^ Bernstein coefficients @c_0 .. c_n@.
  }
  deriving Show
-- Operator precedences mirror the Prelude: the multiplicative operators
-- bind like (*) (infixl 7), the additive ones like (+) and (-) (infixl 6).
infixl 7 *~
infixl 7 ~*
infixl 6 ~-
infixl 6 ~+
-- | Build a Bernstein polynomial from its coefficient list.  The degree is
-- one less than the number of coefficients.  An empty list yields
-- 'zeroPoly'.
listToBernstein :: [Double] -> BernsteinPoly
listToBernstein [] = zeroPoly
listToBernstein coeffs = BernsteinPoly (length coeffs - 1) coeffs
-- | The zero polynomial, represented as degree 0 with a single
-- coefficient of 0.
zeroPoly :: BernsteinPoly
zeroPoly = BernsteinPoly { bernsteinDegree = 0, bernsteinCoeffs = [0] }
-- | The part of a polynomial between parameters @t1@ and @t2@ (given in
-- either order), reparameterised so that the result's own parameter sweeps
-- exactly that part.
--
-- The usual route splits at @t2@ and then splits the left piece at
-- @t1 / t2@.  When @t2 == 0@ that ratio is 0/0 (or -infinity) and would
-- produce NaN coefficients.  In that case the polynomial is split at @t1@
-- instead, and the right piece is split at @(t2 - t1) / (1 - t1)@.  Because
-- @t1 <= t2 == 0@, that divisor is at least 1.
bernsteinSubsegment :: BernsteinPoly -> Double -> Double -> BernsteinPoly
bernsteinSubsegment b t1 t2
  | t1 > t2 = bernsteinSubsegment b t2 t1
  | t2 == 0 =
      fst $ bernsteinSplit (snd $ bernsteinSplit b t1) ((t2 - t1) / (1 - t1))
  | otherwise =
      snd $ bernsteinSplit (fst $ bernsteinSplit b t2) (t1 / t2)
-- | Multiply two Bernstein polynomials.  The product's degree is the sum of
-- the operand degrees.
--
-- Each coefficient list is scaled by its binomial coefficients, and the
-- scaled lists are convolved.  Convolution term @k@ is then divided by
-- @C(la + lb, k)@ to return to the Bernstein basis.
(~*) :: BernsteinPoly -> BernsteinPoly -> BernsteinPoly
(BernsteinPoly la a) ~* (BernsteinPoly lb b) =
  BernsteinPoly degree (zipWith (\c s -> s / c) (binCoeff degree) convolution)
  where
    degree = la + lb
    scaledA = zipWith (*) a (binCoeff la)
    scaledB = zipWith (*) b (binCoeff lb)
    -- [b0], [b1, b0], [b2, b1, b0], ...: reversed prefixes of scaledB.
    -- Pairing scaledA with each one gives convolution terms 0 .. lb.
    reversedPrefixes = drop 1 (scanl (flip (:)) [] scaledB)
    lowerTerms = [zipWith (*) scaledA prefix | prefix <- reversedPrefixes]
    -- Suffixes of scaledA (from index 1) paired with the reversed scaledB
    -- give the remaining terms.  The final, empty suffix contributes a
    -- spurious trailing 0, which 'init' drops.
    upperTerms =
      [zipWith (*) suffix (reverse scaledB) | suffix <- drop 1 (tails scaledA)]
    convolution = init (map sum (lowerTerms ++ upperTerms))
-- | The binomial coefficients @C(n, 0) .. C(n, n)@ as Doubles.  For negative
-- @n@ the result is @[1]@.
--
-- Each entry comes from the previous one via
-- @C(n, m) = C(n, m-1) * (n - m + 1) / m@, a division that is always exact.
-- The running product is kept in 'Integer': in 'Int' it overflowed for
-- degrees around 60 and above.
binCoeff :: Int -> [Double]
binCoeff n = map fromInteger $ scanl step 1 [1 .. nI]
  where
    nI = toInteger n
    step c m = c * (nI - m + 1) `quot` m
-- | Raise the degree of a Bernstein polynomial @times@ times without
-- changing the polynomial it represents.
--
-- One step maps degree-@n@ coefficients @c_0 .. c_n@ to the degree-@(n+1)@
-- coefficients
-- @c_0, (i/(n+1)) c_{i-1} + (1 - i/(n+1)) c_i (for i = 1 .. n), c_n@.
-- A non-positive @times@ returns the input unchanged.  Calls 'error' on an
-- empty coefficient list.
degreeElevate :: BernsteinPoly -> Int -> BernsteinPoly
degreeElevate b@(BernsteinPoly lp p) times
  | times <= 0 = b
  | otherwise =
      degreeElevate (BernsteinPoly (lp + 1) (elevateOnce p)) (times - 1)
  where
    n = fromIntegral lp :: Double
    elevateOnce [] = error "empty bernstein coefficients"
    elevateOnce cs@(c0 : _) = c0 : blend cs 1
    -- Walk the adjacent pairs (c_{i-1}, c_i).  The final single element
    -- is c_n, which is kept as-is.
    blend (a : rest@(c : _)) i =
      (i * a / (n + 1) + c * (1 - i / (n + 1))) : blend rest (i + 1)
    blend remaining _ = remaining
-- | Evaluate a Bernstein polynomial at parameter @t@.
--
-- Computes @sum_{i=0}^{n} C(n,i) t^i (1-t)^(n-i) c_i@ with a Horner-like
-- loop.  The running sum is multiplied by @(1 - t)@ after each term, while
-- @t^i@ and @C(n, i)@ are updated incrementally.  The loop accumulators
-- are forced with 'seq', so no thunk chain builds up and no BangPatterns
-- extension is needed.  Calls 'error' on an empty coefficient list.
bernsteinEval :: BernsteinPoly -> Double -> Double
bernsteinEval (BernsteinPoly _ []) _ = error "empty bernstein coefficients"
bernsteinEval (BernsteinPoly _ [c]) _ = c
bernsteinEval (BernsteinPoly lp (c0 : cs)) t = go t n (c0 * u) 2 cs
  where
    u = 1 - t
    n = fromIntegral lp
    -- On entry with index i: tn = t^(i-1), bc = C(n, i-1), and
    -- acc = sum_{j < i-1} C(n,j) t^j (1-t)^(i-1-j) c_j.
    go tn bc acc _ [c] = acc + tn * bc * c
    go tn bc acc i (c : rest) =
      let tn' = tn * t
          bc' = bc * (n - i + 1) / i
          acc' = (acc + tn * bc * c) * u
      in tn' `seq` bc' `seq` acc' `seq` go tn' bc' acc' (i + 1) rest
    go _ _ acc _ [] = acc
-- | The value of a polynomial and of its successive derivatives at @t@.
-- The list starts with the polynomial itself and ends with the first
-- derivative whose degree is 0.
bernsteinEvalDerivs :: BernsteinPoly -> Double -> [Double]
bernsteinEvalDerivs b t = map (`bernsteinEval` t) (derivChain b)
  where
    -- b, b', b'', ... up to and including the first degree-0 polynomial.
    derivChain p
      | bernsteinDegree p /= 0 = p : derivChain (bernsteinDeriv p)
      | otherwise = [p]
-- | Derivative of a Bernstein polynomial with respect to its parameter.
--
-- For degree @n > 0@ the result has degree @n - 1@ and coefficients
-- @n * (c_{i+1} - c_i)@.  A degree-0 (constant) polynomial yields
-- 'zeroPoly'.
bernsteinDeriv :: BernsteinPoly -> BernsteinPoly
bernsteinDeriv (BernsteinPoly 0 _) = zeroPoly
bernsteinDeriv (BernsteinPoly lp p) =
  BernsteinPoly (lp - 1) $
    map (* fromIntegral lp) $ zipWith (-) (drop 1 p) p
-- | Split a Bernstein polynomial at parameter @t@ using de Casteljau's
-- algorithm.
--
-- Returns @(left, right)@, both with the original degree.  @left@ takes
-- the first point of every de Casteljau level, and @right@ takes the last
-- point of every level in reverse order.  An empty coefficient list yields
-- two polynomials with empty coefficient lists instead of crashing.
bernsteinSplit :: BernsteinPoly -> Double -> (BernsteinPoly, BernsteinPoly)
bernsteinSplit (BernsteinPoly lp p) t =
  ( BernsteinPoly lp [c | (c : _) <- levels]
  , BernsteinPoly lp (reverse (map last levels))
  )
  where
    interp a b = (1 - t) * a + t * b
    -- One de Casteljau step: interpolate every adjacent pair of points.
    step l = zipWith interp l (drop 1 l)
    -- Every level, from the original coefficients down to a single point.
    -- All of them are non-empty, so 'last' is safe here.
    levels = takeWhile (not . null) (iterate step p)
-- | Add two Bernstein polynomials.  The operand of lower degree is first
-- degree-elevated to match the other, so the sum has degree @max la lb@.
(~+) :: BernsteinPoly -> BernsteinPoly -> BernsteinPoly
ba@(BernsteinPoly la _) ~+ bb@(BernsteinPoly lb _) =
  BernsteinPoly d (zipWith (+) (elevatedCoeffs ba) (elevatedCoeffs bb))
  where
    d = max la lb
    -- Coefficients after raising to degree d.  Elevating by 0 returns the
    -- polynomial unchanged.
    elevatedCoeffs p@(BernsteinPoly l _) =
      bernsteinCoeffs (degreeElevate p (d - l))
-- | Subtract the second Bernstein polynomial from the first.  The operand
-- of lower degree is first degree-elevated to match the other, so the
-- difference has degree @max la lb@.
(~-) :: BernsteinPoly -> BernsteinPoly -> BernsteinPoly
ba@(BernsteinPoly la _) ~- bb@(BernsteinPoly lb _) =
  BernsteinPoly d (zipWith (-) (elevatedCoeffs ba) (elevatedCoeffs bb))
  where
    d = max la lb
    -- Coefficients after raising to degree d.  Elevating by 0 returns the
    -- polynomial unchanged.
    elevatedCoeffs p@(BernsteinPoly l _) =
      bernsteinCoeffs (degreeElevate p (d - l))
-- | Scale every coefficient of a Bernstein polynomial by a constant.  The
-- degree is unchanged.
(*~) :: Double -> BernsteinPoly -> BernsteinPoly
factor *~ poly@(BernsteinPoly _ coeffs) =
  poly { bernsteinCoeffs = [c * factor | c <- coeffs] }