module Numeric.AD.Newton
(
findZero
, inverse
, fixedPoint
, extremum
, gradientDescent
, gradientAscent
, conjugateGradientDescent
, conjugateGradientAscent
) where
import Prelude hiding (all, mapM, sum)
import Data.Functor
import Data.Foldable (all, sum)
import Data.Traversable
import Numeric.AD.Types
import Numeric.AD.Mode.Forward (diff, diff')
import Numeric.AD.Mode.Reverse (grad, gradWith')
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Composition
-- | Newton-Raphson root finding.
--
-- @findZero f x0@ lazily produces the successive approximations to a zero of
-- @f@, beginning with the initial guess @x0@ itself.
--
-- Every step moves from @x@ to @x - y / y'@, where @(y, y')@ is the pair
-- obtained from @diff' f x@ (used here as the value of @f@ at @x@ and its
-- derivative there). The list ends right after the first iterate at which
-- @y@ compares equal to zero. If that never happens the list is infinite, so
-- callers normally consume only a bounded prefix (e.g. with 'take').
--
-- No guard is placed on @y'@: a vanishing derivative is divided by as-is.
findZero :: (Fractional a, Eq a) => (forall s. Mode s => AD s a -> AD s a) -> a -> [a]
findZero f = go
  where
    go x = x : if y == 0 then [] else go (x - y / y')
      where
        (y, y') = diff' f x
-- | Numerically invert a function with Newton's method.
--
-- @inverse f x0 y@ returns the iterates produced by 'findZero' for the
-- residual @\\x -> f x - lift y@, starting from the guess @x0@; a zero of
-- that residual is a point @x@ with @f x == y@.
--
-- The target @y@ is a plain value, so it is embedded with 'lift' before it
-- can be subtracted from the AD-typed result of @f@.
--
-- Termination follows 'findZero': the list may be infinite.
inverse :: (Fractional a, Eq a) => (forall s. Mode s => AD s a -> AD s a) -> a -> a -> [a]
inverse f x0 y = findZero (\x -> f x - lift y) x0
-- | Search for a fixed point of a function with Newton's method.
--
-- @fixedPoint f x0@ returns the iterates produced by 'findZero' for
-- @\\x -> f x - x@, starting from @x0@; a zero of that difference is a point
-- with @f x == x@.
--
-- Termination follows 'findZero': the list may be infinite.
fixedPoint :: (Fractional a, Eq a) => (forall s. Mode s => AD s a -> AD s a) -> a -> [a]
fixedPoint f = findZero (\x -> f x - x)
-- | Search for a stationary point of a function.
--
-- @extremum f x0@ hands the result of 'diff' applied to @f@ to 'findZero',
-- starting from @x0@, so the iterates approach a point where that derivative
-- vanishes. Such a point may be a minimum, a maximum or neither; nothing
-- here distinguishes between them.
--
-- 'findZero' differentiates its argument once more, so @f@ is evaluated
-- under two nested AD modes; 'composeMode' and 'decomposeMode' wrap and
-- unwrap the argument and result of @f@ for that nesting.
extremum :: (Fractional a, Eq a) => (forall s. Mode s => AD s a -> AD s a) -> a -> [a]
extremum f = findZero (diff (\z -> decomposeMode (f (composeMode z))))
-- | Gradient descent with a simple adaptive step size.
--
-- @gradientDescent f x0@ lazily yields the accepted iterates while trying to
-- minimise @f@, starting from @x0@. The starting point itself is not part of
-- the output; the first element is the first accepted step.
--
-- @gradWith' (,) f x@ is used to obtain the objective value at @x@ together
-- with a container pairing every coordinate with its partial derivative.
--
-- The step size @eta@ starts at @0.1@ and is adapted as follows:
--
-- * a candidate that increases the objective is rejected and retried from
--   the same point with @eta@ halved;
--
-- * once the acceptance counter has reached 10, the next accepted step also
--   doubles @eta@ and resets the counter.
--
-- The list ends when @eta@ has reached exactly zero or when every partial
-- derivative at the current point is exactly zero.
gradientDescent :: (Traversable f, Fractional a, Ord a) => (forall s. Mode s => f (AD s a) -> AD s a) -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    (fx0, xgx0) = gradWith' (,) f x0
    go x fx xgx !eta !i
      | eta == 0 = []                          -- step size underflowed: give up
      | fx1 > fx = go x fx xgx (eta / 2) 0     -- overshoot: retry with a smaller step
      | zeroGrad xgx = []                      -- stationary point reached
      | otherwise = x1 : if i == 10
          then go x1 fx1 xgx1 (eta * 2) 0      -- a run of successes: grow the step
          else go x1 fx1 xgx1 eta (i + 1)
      where
        zeroGrad = all (\(_, g) -> g == 0)
        -- Move against the gradient: x1 = x - eta * grad f x.
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        (fx1, xgx1) = gradWith' (,) f x1
-- | Gradient ascent: maximise @f@ by running 'gradientDescent' on the
-- negated objective. The step-size handling and stopping conditions are
-- exactly those of 'gradientDescent'.
gradientAscent :: (Traversable f, Fractional a, Ord a) => (forall s. Mode s => f (AD s a) -> AD s a) -> f a -> [f a]
gradientAscent f = gradientDescent (\xs -> negate (f xs))
-- | Nonlinear conjugate gradient descent.
--
-- @conjugateGradientDescent f x0@ lazily yields an infinite list of iterates
-- while trying to minimise @f@; the first element is @x0@ itself. There is no
-- built-in stopping rule, so callers must take a bounded prefix.
--
-- One iteration, with current point @xi@, residual @ri@ (the negated
-- gradient at @xi@) and search direction @di@:
--
-- 1. Line search: @ai@ is the last of at most 20 'extremum' iterates,
--    started at 0, for the one-dimensional function @\\a -> f (xi + a * di)@.
--
-- 2. Step to @xi1 = xi + ai * di@ and recompute the residual @ri1@.
--
-- 3. Compute the Polak-Ribiere coefficient, clamped below at zero so that a
--    negative value restarts from the steepest-descent direction:
--    @bi1 = max 0 (ri1 . (ri1 - ri) / (ri . ri))@.
--
-- 4. Form the new direction @di1 = ri1 + bi1 * di@.
conjugateGradientDescent :: (Traversable f, Fractional a, Ord a) =>
  (forall s. Mode s => f (AD s a) -> AD s a) -> f a -> [f a]
conjugateGradientDescent f x0 = go x0 d0 d0
  where
    -- Inner product of two containers of the same shape.
    dot x y = sum $ zipWithT (*) x y
    -- The initial residual doubles as the initial search direction.
    d0 = negate <$> grad f x0
    go xi ri di = xi : go xi1 ri1 di1
      where
        -- 'extremum' always yields at least its starting value, so the
        -- truncated list is non-empty and 'last' is safe here.
        ai = last $ take 20 $ extremum (\a -> f $ zipWithT (\x d -> lift x + a * lift d) xi di) 0
        xi1 = zipWithT (\x d -> x + ai * d) xi di
        ri1 = negate <$> grad f xi1
        -- Polak-Ribiere: the denominator is the squared norm of the
        -- previous residual, not of the new one.
        bi1 = max 0 $ dot ri1 (zipWithT (-) ri1 ri) / dot ri ri
        di1 = zipWithT (\r d -> r + bi1 * d) ri1 di
-- | Nonlinear conjugate gradient ascent: maximise @f@ by running
-- 'conjugateGradientDescent' on the negated objective. The resulting list of
-- iterates has the same shape as the one produced by
-- 'conjugateGradientDescent'.
conjugateGradientAscent :: (Traversable f, Fractional a, Ord a) => (forall s. Mode s => f (AD s a) -> AD s a) -> f a -> [f a]
conjugateGradientAscent f = conjugateGradientDescent (\xs -> negate (f xs))