module Data.Number.Dif(Dif, val, df, mkDif, dCon, dVar, deriv, unDif) where
-- | A differentiable number: a value paired with the (possibly infinite,
-- lazily built) tower of its derivatives.
data Dif a
    = D !a (Dif a)  -- ^ a value followed by its derivative, itself a 'Dif'
    | C !a          -- ^ a constant: every derivative is zero
-- | Turn an ordinary number into a constant 'Dif' number (all derivatives
-- zero). Numeric literals are already overloaded, so they need no call.
dCon :: (Num a) => a -> Dif a
dCon = C
-- | Turn a number into the variable of differentiation: its value is the
-- argument and its first derivative is the constant 1.
dVar :: (Num a, Eq a) => a -> Dif a
dVar v = D v (C 1)
-- | The first derivative of a 'Dif' number; iterate it for higher
-- derivatives. A constant has derivative zero.
df :: (Num a, Eq a) => Dif a -> Dif a
df d = case d of
    D _ d' -> d'
    C _    -> C 0
-- | The ordinary value of a 'Dif' number, discarding its derivatives.
val :: Dif a -> a
val d = case d of
    D v _ -> v
    C v   -> v
-- | Build a 'Dif' number from a value and the 'Dif' number holding its
-- derivatives (a direct use of the 'D' constructor).
mkDif :: a -> Dif a -> Dif a
mkDif v ds = D v ds
-- | The derivative of a single-argument 'Dif' function, as an ordinary
-- function: evaluate at a 'dVar' point, take 'df', keep the 'val'.
deriv :: (Num a, Num b, Eq a, Eq b) => (Dif a -> Dif b) -> (a -> b)
deriv f x = val (df (f (dVar x)))
-- | Convert a 'Dif' function back into an ordinary function by evaluating
-- it at a 'dVar' point and keeping only the value.
unDif :: (Num a, Eq a) => (Dif a -> Dif b) -> (a -> b)
unDif f x = val (f (dVar x))
-- | Shows only the value; the trailing "~~" marks the omitted tail of
-- derivatives.
instance (Show a) => Show (Dif a) where
    show d = shownValue ++ "~~"
      where
        shownValue = show (val d)
-- | Reads a plain value and lifts it to a constant 'Dif' number.
instance (Read a) => Read (Dif a) where
    readsPrec prec input =
        map (\(v, rest) -> (C v, rest)) (readsPrec prec input)
-- | Equality compares values only; derivative tails are ignored.
instance (Eq a) => Eq (Dif a) where
    (==) l r = val l == val r
-- | Ordering compares values only; derivative tails are ignored.
instance (Ord a) => Ord (Dif a) where
    compare l r = compare (val l) (val r)
-- | Forward-mode arithmetic: each operation combines values directly and
-- builds the derivative tail with the usual differentiation rules. Constants
-- ('C') are kept separate so their (zero) derivatives are never materialized.
instance (Num a, Eq a) => Num (Dif a) where
    -- Sum rule: (f + g)' = f' + g'.
    (C x)    + (C y)    = C (x + y)
    (C x)    + (D y y') = D (x + y) y'
    (D x x') + (C y)    = D (x + y) x'
    (D x x') + (D y y') = D (x + y) (x' + y')

    -- Difference rule: (f - g)' = f' - g'; a constant minus a variable
    -- contributes the negated derivative of the variable.
    (C x)    - (C y)    = C (x - y)
    (C x)    - (D y y') = D (x - y) (negate y')
    (D x x') - (C y)    = D (x - y) x'
    (D x x') - (D y y') = D (x - y) (x' - y')

    -- A literal zero constant short-circuits to C 0 without inspecting the
    -- other operand, which keeps derivative towers from growing needlessly.
    (C 0) * _ = C 0
    _ * (C 0) = C 0
    (C x) * (C y) = C (x * y)
    -- Product rule: (f * g)' = f' * g + f * g'.
    p@(C x) * (D y y') = D (x * y) (p * y')
    (D x x') * q@(C y) = D (x * y) (x' * q)
    p@(D x x') * q@(D y y') = D (x * y) (x' * q + p * y')

    negate (C x)    = C (negate x)
    negate (D x x') = D (negate x) (negate x')

    fromInteger i = C (fromInteger i)

    -- |f|' = signum f * f'
    abs (C x)      = C (abs x)
    abs p@(D x x') = D (abs x) (signum p * x')

    -- signum is treated as piecewise constant: its derivative (a Dirac
    -- impulse at 0) is not representable, so the result is a constant.
    signum (C x)   = C (signum x)
    signum (D x _) = C (signum x)
-- | Division is derived from 'recip' (default @x / y = x * recip y@).
instance (Fractional a, Eq a) => Fractional (Dif a) where
    recip (C x)    = C (recip x)
    -- (1/f)' = -f' / f^2 = -f' * (1/f) * (1/f); 'ip' refers to itself so the
    -- higher derivatives are produced lazily on demand.
    recip (D x x') = ip
      where
        ip = D (recip x) (negate x' * ip * ip)
    fromRational r = C (fromRational r)
-- | Lift a scalar function to 'Dif' numbers, given the list
-- @[f, f', f'', ...]@ of the function followed by its successive derivatives.
-- The derivative tail follows the chain rule: @(f . p)' = p' * (f' . p)@.
-- Calls 'error' when invoked with an empty list.
lift :: (Num a, Eq a) => [a -> a] -> Dif a -> Dif a
lift fs p = case (fs, p) of
    (g : _,  C x)    -> C (g x)
    (g : gs, D x x') -> D (g x) (x' * lift gs p)
    _                -> error "lift"
-- | Elementary functions on 'Dif' numbers. Each equation for a 'D' argument
-- applies the chain rule: the value is the ordinary function result and the
-- derivative is @x' * f'(p)@.
instance (Floating a, Eq a) => Floating (Dif a) where
    pi = C pi

    exp (C x)      = C (exp x)
    -- (e^f)' = f' * e^f; r refers to itself, so the tower is built lazily.
    exp (D x x')   = r where r = D (exp x) (x' * r)

    log (C x)      = C (log x)
    log p@(D x x') = D (log x) (x' / p)

    -- (sqrt f)' = f' / (2 * sqrt f), again self-referential via r.
    sqrt (C x)     = C (sqrt x)
    sqrt (D x x')  = r where r = D (sqrt x) (x' / (2 * r))

    -- The derivatives of sin and cos repeat with period 4.
    sin = lift (cycle [sin, cos, negate . sin, negate . cos])
    cos = lift (cycle [cos, negate . sin, negate . cos, sin])

    -- acos' = -1 / sqrt (1 - x^2)
    acos (C x)      = C (acos x)
    acos p@(D x x') = D (acos x) (negate x' / sqrt (1 - p*p))
    -- asin' = 1 / sqrt (1 - x^2)
    asin (C x)      = C (asin x)
    asin p@(D x x') = D (asin x) (x' / sqrt (1 - p*p))
    -- atan' = 1 / (1 + x^2)
    atan (C x)      = C (atan x)
    atan p@(D x x') = D (atan x) (x' / (1 + p*p))

    -- Hyperbolic functions via their exponential/logarithmic identities, so
    -- derivatives come for free from exp, log and sqrt above.
    sinh x  = (exp x - exp (negate x)) / 2
    cosh x  = (exp x + exp (negate x)) / 2
    asinh x = log (x + sqrt (x*x + 1))
    acosh x = log (x + sqrt (x*x - 1))
    atanh x = (log (1 + x) - log (1 - x)) / 2
-- | Conversion to 'Rational' uses the value only.
instance (Real a) => Real (Dif a) where
    toRational d = toRational (val d)
-- | Rounding operations look at the value only.
instance (RealFrac a) => RealFrac (Dif a) where
    -- The integral part is taken from the value; the fractional part is
    -- x - i, so it keeps x's derivatives (i is a constant).
    properFraction x = (i, x - fromIntegral i)
      where
        (i, _) = properFraction (val x)
    truncate = truncate . val
    round    = round . val
    ceiling  = ceiling . val
    floor    = floor . val
-- | Floating-point format queries and classification predicates look at the
-- value only.
instance (RealFloat a) => RealFloat (Dif a) where
    floatRadix  = floatRadix . val
    floatDigits = floatDigits . val
    floatRange  = floatRange . val
    -- Previously missing: without these the class defaults (e.g. for
    -- 'significand') hit a missing-method error at runtime.
    decodeFloat = decodeFloat . val
    encodeFloat m e = C (encodeFloat m e)
    -- exponent is fixed at 0, so significand must be the identity to keep
    -- x == significand x * floatRadix x ^^ exponent x.
    exponent _    = 0
    significand x = x
    -- scaleFloat n multiplies by radix^n, a linear map, so it scales every
    -- derivative by the same factor. Previously only n == 0 was handled.
    scaleFloat n (C x)    = C (scaleFloat n x)
    scaleFloat n (D x x') = D (scaleFloat n x) (scaleFloat n x')
    isNaN          = isNaN . val
    isInfinite     = isInfinite . val
    isDenormalized = isDenormalized . val
    isNegativeZero = isNegativeZero . val
    isIEEE         = isIEEE . val