module Data.Number.IReal.FAD where
import Data.Number.IReal.Scalable
import Data.Number.IReal.Powers
-- | A list of components of type @a@. The functions in this module use
-- it as a tower @[value, first derivative, second derivative, ...]@;
-- the list may be finite or infinite.
newtype Dif a = D [a]
  deriving (Show)
-- | 'con' wraps a constant: the tower holds only the value itself.
-- 'var' wraps an independent variable: the value followed by a first
-- derivative of 1.
con, var :: Num a => a -> Dif a
con = D . (: [])
var x = D (x : [1])
-- | Prepend a value to an existing tower. Because 'D' is a newtype
-- constructor, the tower argument is not forced by the match, so this can
-- be used in self-referential definitions.
mkDif :: a -> Dif a -> Dif a
mkDif h t = case t of
  D rest -> D (h : rest)
-- | First component of the tower; an empty tower yields 0.
val :: Num a => Dif a -> a
val (D components) =
  case components of
    []    -> 0
    v : _ -> v
-- | Unwrap the underlying list of components.
fromDif :: Num a => Dif a -> [a]
fromDif = \(D components) -> components
-- | Turn a function on towers into a function on plain values: feed it
-- the argument as a variable (see 'var') and keep only the leading
-- component of the result.
unDif :: (Num a, Num b) => (Dif a -> Dif b) -> a -> b
unDif f x = val (f (var x))
-- | Drop the first @n@ components of a tower, i.e. shift it so that the
-- former @n@-th component becomes the leading one.
df :: Int -> Dif a -> Dif a
df n tower = case tower of
  D components -> D (drop n components)
-- | @deriv n f x@ applies @f@ to @x@ as a variable and returns the
-- component at index @n@ of the resulting tower (0 if the tower is
-- shorter than that).
deriv :: (Num a, Num b) => Int -> (Dif a -> Dif b) -> a -> b
deriv n f x = unDif (\d -> df n (f d)) x
-- | All components of the tower obtained by applying @f@ to @x@ as a
-- variable. The list may be infinite.
derivs :: (Num a, Num b) => (Dif a -> Dif b) -> a -> [b]
derivs f x = fromDif (f (var x))
-- | Chain-rule lifting of a scalar function @f@ to towers.
--
-- 'chain' takes @f'@ as a tower function applied to the argument @g@:
-- the result is @f (val g)@ followed by @df 1 g * f' g@.
--
-- 'rchain' is the same except that @f'@ is applied to the result being
-- defined rather than to @g@, for functions whose derivative is expressed
-- through the function value itself. The definition is self-referential
-- and relies on laziness.
chain, rchain :: Num a => (a -> a) -> (Dif a -> Dif a) -> Dif a -> Dif a
chain f f' g = mkDif leading rest
  where
    leading = f (val g)
    rest = df 1 g * f' g

rchain f f' g =
  let result = mkDif (f (val g)) (df 1 g * f' result)
   in result
-- | Chain-rule lifting for a pair of mutually dependent functions.
--
-- Two towers are defined together: the returned one starts with
-- @f1 (val g)@ and its remaining components are @df 1 g * f1' partner@;
-- the partner starts with @f2 (val g)@ and continues with
-- @df 1 g * f2' primary@. Only the first tower is returned.
r2chain
  :: Num a
  => (a -> a)
  -> (a -> a)
  -> (Dif a -> Dif a)
  -> (Dif a -> Dif a)
  -> Dif a
  -> Dif a
r2chain f1 f2 f1' f2' g =
  let inner = df 1 g
      primary = mkDif (f1 (val g)) (inner * f1' partner)
      partner = mkDif (f2 (val g)) (inner * f2' primary)
   in primary
-- | 'precB' is applied with the same first argument to every component
-- of the tower.
instance VarPrec a => VarPrec (Dif a) where
  precB b (D components) = D [precB b c | c <- components]
-- | Equality looks only at the leading components (see 'val'); the
-- remaining components are ignored.
instance (Num a, Eq a) => Eq (Dif a) where
  (==) x y =
    let vx = val x
        vy = val y
     in vx == vy
-- | Ordering looks only at the leading components (see 'val').
instance (Num a, Ord a) => Ord (Dif a) where
  compare l r = val l `compare` val r
-- | Arithmetic on towers.
instance Num a => Num (Dif a) where
  -- Component-wise sum. The recursion on the tails has no base case and
  -- 'val' of an empty tower is 0, so the result is always an infinite
  -- list whose spine can be produced without inspecting the operands.
  x + y = mkDif (val x + val y) (df 1 x + df 1 y)

  -- Product of towers, computed by 'convs'.
  x * y = D (convs (fromDif x) (fromDif y))

  abs = chain abs signum

  -- d/dx (-x) = -1. The literal @-1@ at tower type is itself evaluated
  -- via 'negate' on a one-element tower; that terminates because the
  -- product with an empty tail never looks at its second factor.
  negate = chain negate (const (-1))

  signum = chain signum (const 0)

  fromInteger = con . fromInteger
-- | Reciprocal and rational literals for towers.
instance (Fractional a, Powers a) => Fractional (Dif a) where
  -- With r = 1/g the derivative is r' = -g' * r^2. The result @r@ is
  -- defined in terms of itself, in the same way as in 'rchain'.
  recip g = r
    where
      r = mkDif (recip (val g)) (negate (df 1 g * sq r))

  fromRational = con . fromRational
-- | Elementary functions on towers. Each method lifts the scalar function
-- with 'chain', 'rchain' or 'r2chain', supplying its derivative rule.
instance (Floating a, Powers a) => Floating (Dif a) where
  pi = con pi

  -- exp' = exp
  exp = rchain exp id
  -- log' x = 1 / x
  log = chain log recip
  -- sqrt' = 1 / (2 * sqrt)
  sqrt = rchain sqrt (recip . (* 2))

  -- sin' = cos, cos' = -sin
  sin = r2chain sin cos id negate
  cos = r2chain cos sin negate id
  -- tan' = 1 + tan^2
  tan = rchain tan ((1 +) . sq)

  -- asin' x = 1 / sqrt (1 - x^2)
  asin = chain asin (recip . sqrt . (1 -) . sq)
  -- acos' x = -1 / sqrt (1 - x^2)
  acos = chain acos (negate . recip . sqrt . (1 -) . sq)
  -- atan' x = 1 / (1 + x^2)
  atan = chain atan (recip . (1 +) . sq)

  -- sinh' = cosh, cosh' = sinh
  sinh = r2chain sinh cosh id id
  cosh = r2chain cosh sinh id id

  -- asinh' x = 1 / sqrt (1 + x^2)
  asinh = chain asinh (recip . sqrt . (1 +) . sq)
  -- acosh' x = 1 / sqrt (x^2 - 1)
  acosh = chain acosh (recip . sqrt . subtract 1 . sq)
  -- atanh' x = 1 / (1 - x^2)
  atanh = chain atanh (recip . (1 -) . sq)
-- | 'Powers' methods lifted to towers.
-- NOTE(review): the 'Powers' class is defined in another module; the
-- rules below assume @sq@ is squaring and @pow x n@ is @x@ raised to the
-- integer power @n@ — confirm against the class definition.
instance (Num a, Powers a) => Powers (Dif a) where
  -- (x^2)' = 2 * x
  sq = chain sq (2 *)

  -- Exponent 0 gives the constant 1 regardless of the argument.
  pow _ 0 = con 1
  -- (x^n)' = n * x^(n - 1)
  pow x n = chain (flip pow n) ((fromIntegral n *) . flip pow (n - 1)) x
-- | Conversion to 'Rational' uses only the leading component.
instance Real a => Real (Dif a) where
  toRational tower = toRational (val tower)
-- | Rounding operations. The integral results are all computed from the
-- leading component of the tower.
instance (Powers a, RealFrac a) => RealFrac (Dif a) where
  -- The integer part comes from the leading component; the fractional
  -- part is the whole tower minus that integer.
  properFraction x = (i, x - fromIntegral i)
    where
      (i, _) = properFraction (val x)

  truncate = truncate . val
  round = round . val
  ceiling = ceiling . val
  floor = floor . val
-- | Floating-point queries are answered from the leading component.
-- 'scaleFloat' is applied to every component, and 'encodeFloat' builds a
-- constant tower.
instance (Powers a, RealFloat a) => RealFloat (Dif a) where
  floatRadix d = floatRadix (val d)
  floatDigits d = floatDigits (val d)
  floatRange d = floatRange (val d)
  exponent d = exponent (val d)
  scaleFloat n (D components) = D [scaleFloat n c | c <- components]
  isNaN d = isNaN (val d)
  isInfinite d = isInfinite (val d)
  isDenormalized d = isDenormalized (val d)
  isNegativeZero d = isNegativeZero (val d)
  isIEEE d = isIEEE (val d)
  decodeFloat d = decodeFloat (val d)
  encodeFloat m = con . encodeFloat m
-- | Product of two derivative lists: entry @n@ of the result is the sum
-- over @k@ of @binomial n k * xs_k * ys_(n-k)@, with entries past the end
-- of a finite list counting as zero.
--
-- The result is empty if either input is empty. Entry @n@ inspects the
-- second list only up to index @n@, so the second argument may be
-- infinite or defined in terms of the result.
convs :: Num a => [a] -> [a] -> [a]
convs [] _ = []
convs (x : xs) ys = grow [1] [x] xs ys
  where
    -- While entries of the left list remain: @coeffs@ is the current row
    -- of binomial coefficients and @seen@ holds the left entries consumed
    -- so far, most recent first.
    grow _ _ _ [] = []
    grow coeffs seen [] rhs =
      sumProd3 coeffs seen rhs : slide (next' coeffs) seen rhs
    grow coeffs seen (l : ls) rhs =
      sumProd3 coeffs seen rhs : grow (next coeffs) (l : seen) ls rhs

    -- After the left list is exhausted: slide along the right list one
    -- entry per output, stopping when a single entry is left.
    slide _ _ [_] = []
    slide coeffs seen (_ : rhs) =
      sumProd3 coeffs seen rhs : slide (next' coeffs) seen rhs

    -- Full next row of Pascal's triangle.
    next row = 1 : next' row
    -- Next row without its leading 1 (same length as the input row).
    next' row = zipWith (+) row (tail row) ++ [1]

    -- Sum of the element-wise triple products, over the shortest list.
    sumProd3 ps qs rs = sum (zipWith3 (\p q r -> p * q * r) ps qs rs)