module Numeric.AD.Newton
(
findZero
, findZeroM
, inverse
, inverseM
, fixedPoint
, fixedPointM
, extremum
, extremumM
, gradientDescent
, gradientDescentM
, gradientAscent
, gradientAscentM
, AD(..)
, Mode(..)
) where
import Prelude hiding (all)
import Control.Monad (liftM)
import Data.MList
import Numeric.AD.Classes
import Numeric.AD.Internal
import Data.Foldable (all)
import Data.Traversable (Traversable)
import Numeric.AD.Forward (diff, diff', diffM, diffM')
import Numeric.AD.Reverse (gradWith', gradWithM')
import Numeric.AD.Internal.Composition
-- | Newton's method for a root of a scalar function.
--
-- Starting from the initial guess, produce the (unbounded) list of Newton
-- iterates @x_{n+1} = x_n - f x_n / f' x_n@. The value and derivative at
-- each point come from the forward-mode 'diff''. The initial guess is the
-- head of the result.
--
-- No stopping criterion is applied: callers decide how many iterates to
-- consume. There is no safeguard against a vanishing derivative; the step
-- then divides by zero under the 'Fractional' semantics of @a@.
findZero :: Fractional a => (forall s. Mode s => AD s a -> AD s a) -> a -> [a]
findZero f = go
  where
    go x = x : go (x - y / y')
      where
        -- diff' yields the pair (f x, f' x) used by the Newton update.
        (y, y') = diff' f x
-- | Monadic counterpart of 'findZero'.
--
-- This runs the same Newton iteration @x_{n+1} = x_n - f x_n / f' x_n@, but
-- the function runs in monad @m@ and is differentiated with 'diffM''. The
-- result is an 'MList': its head is the initial guess, and producing each
-- further element runs @diffM' f@ at the current iterate. The list is not
-- terminated here.
findZeroM :: (Monad m, Fractional a) => (forall s. Mode s => AD s a -> m (AD s a)) -> a -> MList m a
findZeroM f x0 = MList (go x0)
  where
    go x = return $
      MCons x $
        MList $ do
          (y, y') <- diffM' f x
          go (x - y / y')
-- | Approximate the input at which @f@ takes the value @y@.
--
-- This runs 'findZero' on @\\x -> f x - y@ starting from the guess @x0@.
-- 'lift' embeds the constant target @y@ into the AD type.
inverse :: Fractional a => (forall s. Mode s => AD s a -> AD s a) -> a -> a -> [a]
inverse f x0 y = findZero (\x -> f x - lift y) x0
-- | Monadic counterpart of 'inverse'.
--
-- It runs the monadic Newton iteration 'findZeroM' on the shifted function
-- whose result is @f x@ minus the lifted target @y@, starting from @x0@.
inverseM :: (Monad m, Fractional a) => (forall s. Mode s => AD s a -> m (AD s a)) -> a -> a -> MList m a
inverseM f x0 y = findZeroM (\x -> f x >>= \fx -> return (fx - lift y)) x0
-- | Approximate a fixed point of @f@, a value with @f x == x@.
--
-- This applies Newton's method ('findZero') to @\\x -> f x - x@, starting
-- from the supplied guess.
fixedPoint :: Fractional a => (forall s. Mode s => AD s a -> AD s a) -> a -> [a]
fixedPoint f = findZero (\x -> f x - x)
-- | Monadic counterpart of 'fixedPoint'.
--
-- It runs 'findZeroM' on the action that evaluates @f x@ and subtracts
-- @x@ from the result.
fixedPointM :: (Monad m, Fractional a) => (forall s. Mode s => AD s a -> m (AD s a)) -> a -> MList m a
fixedPointM f = findZeroM (\x -> do { fx <- f x; return (fx - x) })
-- | Search for a critical point of @f@, where its derivative vanishes.
--
-- The derivative of @f@ is built with forward-mode 'diff' and handed to
-- 'findZero', which differentiates it once more. The inner function is
-- evaluated over a nested AD type. 'compose' and 'decompose' (from
-- "Numeric.AD.Internal.Composition") presumably convert between that
-- nesting and a single composed mode so @f@ can be applied; confirm there.
extremum :: Fractional a => (forall s. Mode s => AD s a -> AD s a) -> a -> [a]
extremum f = findZero (\x -> diff (\y -> decompose (f (compose y))) x)
-- | Monadic counterpart of 'extremum'.
--
-- 'findZeroM' is run on the monadic derivative of @f@, obtained with
-- 'diffM'. The argument is wrapped with 'compose' before calling @f@, and
-- the result is unwrapped with 'decompose'.
extremumM :: (Monad m, Fractional a) => (forall s. Mode s => AD s a -> m (AD s a)) -> a -> MList m a
extremumM f = findZeroM (\x -> diffM (\y -> do { z <- f (compose y); return (decompose z) }) x)
-- | Minimize @f@ by gradient descent with an adaptive step size.
--
-- Values and gradients come from the reverse-mode 'gradWith'', paired as
-- @(coordinate, partial derivative)@. Each candidate step moves every
-- coordinate to @xi - eta * gxi@.
--
-- Step-size schedule: @eta@ starts at 0.1.
--
-- * A candidate whose objective value is greater than the current one is
--   rejected; @eta@ is halved and the success counter reset.
-- * An accepted candidate is emitted and the counter incremented. When an
--   accepted step happens with the counter at 10, @eta@ is doubled for the
--   next step and the counter reset.
--
-- The result lists the accepted iterates; the starting point itself is not
-- included. The list ends when @eta@ reaches exactly 0 or when the gradient
-- at the current point is exactly zero.
gradientDescent :: (Traversable f, Fractional a, Ord a) => (forall s. Mode s => f (AD s a) -> AD s a) -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    (fx0, xgx0) = gradWith' (,) f x0
    go x fx xgx !eta !i
      | eta == 0     = []
      | fx1 > fx     = go x fx xgx (eta / 2) 0
      | zeroGrad xgx = []
      | otherwise    = x1 : if i == 10
                              then go x1 fx1 xgx1 (eta * 2) 0
                              else go x1 fx1 xgx1 eta (i + 1)
      where
        zeroGrad = all (\(_, g) -> g == 0)
        -- Step against the gradient: subtract eta times each partial.
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        (fx1, xgx1) = gradWith' (,) f x1
-- | Maximize @f@ by running 'gradientDescent' on its negation.
--
-- It uses the same adaptive step-size schedule and termination rules as
-- 'gradientDescent'.
gradientAscent :: (Traversable f, Fractional a, Ord a) => (forall s. Mode s => f (AD s a) -> AD s a) -> f a -> [f a]
gradientAscent f = gradientDescent (\xs -> negate (f xs))
-- | Monadic counterpart of 'gradientDescent'.
--
-- It uses the same step-size schedule: @eta@ starts at 0.1, is halved when
-- a candidate increases the objective, and is doubled when an accepted step
-- occurs with the success counter at 10. Values and gradients come from
-- 'gradWithM''.
--
-- Each iteration runs 'gradWithM'' at the candidate point before deciding
-- whether to accept it, so the effects in @m@ happen once per candidate.
-- The result is an 'MList' of accepted iterates, not including the start.
-- It ends with 'MNil' when @eta@ reaches exactly 0 or the gradient at the
-- current point is exactly zero.
gradientDescentM :: (Traversable f, Monad m, Fractional a, Ord a) => (forall s. Mode s => f (AD s a) -> m (AD s a)) -> f a -> MList m (f a)
gradientDescentM f x0 = MList $ do
    (fx0, xgx0) <- gradWithM' (,) f x0
    go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    go x fx xgx !eta !i
      | eta == 0 = return MNil
      | otherwise = do
          (fx1, xgx1) <- gradWithM' (,) f x1
          case () of
            _ | fx1 > fx     -> go x fx xgx (eta / 2) 0
              | zeroGrad xgx -> return MNil
              | otherwise    -> return $
                  MCons x1 $
                    MList $
                      if i == 10
                        then go x1 fx1 xgx1 (eta * 2) 0
                        else go x1 fx1 xgx1 eta (i + 1)
      where
        -- Step against the gradient: subtract eta times each partial.
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        zeroGrad = all (\(_, g) -> g == 0)
-- | Monadic counterpart of 'gradientAscent'.
--
-- It maximizes @f@ by running 'gradientDescentM' on an action that negates
-- the result of @f@.
gradientAscentM :: (Traversable f, Monad m, Fractional a, Ord a) => (forall s. Mode s => f (AD s a) -> m (AD s a)) -> f a -> MList m (f a)
gradientAscentM f = gradientDescentM (\xs -> f xs >>= return . negate)