{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.AD.Newton.Double
(
findZero
, findZeroNoEq
, inverse
, inverseNoEq
, fixedPoint
, fixedPointNoEq
, extremum
, extremumNoEq
, conjugateGradientDescent
, conjugateGradientAscent
) where
import Data.Foldable (all, sum)
import Data.Traversable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Forward (Forward)
import Numeric.AD.Internal.Forward.Double (ForwardDouble)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Or
import Numeric.AD.Internal.Type (AD(..))
import Numeric.AD.Mode
import Numeric.AD.Rank1.Kahn as Kahn (Kahn, grad)
import qualified Numeric.AD.Rank1.Newton.Double as Rank1
import Prelude hiding (all, mapM, sum)
-- | Search for a zero of a scalar function, starting from the given guess.
--
-- Returns the stream of approximations produced by 'Rank1.findZero'
-- (presumably Newton iterates, per the module name).
--
-- The rank-2 type requires @f@ to work for every brand @s@. This wrapper
-- picks one, wraps the argument with 'AD', unwraps the result with 'runAD',
-- and hands the resulting plain function to the rank-1 implementation.
findZero :: (forall s. AD s ForwardDouble -> AD s ForwardDouble) -> Double -> [Double]
findZero f = Rank1.findZero (\x -> runAD (f (AD x)))
{-# INLINE findZero #-}
-- | Variant of 'findZero' that delegates to 'Rank1.findZeroNoEq'.
--
-- NOTE(review): presumably this differs from 'findZero' only in not
-- stopping when two successive iterates compare equal. Confirm against the
-- Rank1 module.
--
-- The branded function is unwrapped ('AD' in, 'runAD' out) before
-- delegating.
findZeroNoEq :: (forall s. AD s ForwardDouble -> AD s ForwardDouble) -> Double -> [Double]
findZeroNoEq f = Rank1.findZeroNoEq (\x -> runAD (f (AD x)))
{-# INLINE findZeroNoEq #-}
-- | Approximate an inverse of a scalar function via 'Rank1.inverse'.
--
-- The two 'Double' arguments are forwarded unchanged. NOTE(review):
-- presumably they are an initial guess and a target value, in that order,
-- and the stream approximates an @x@ with @f x == target@. Confirm.
--
-- The branded function is unwrapped ('AD' in, 'runAD' out) before
-- delegating.
inverse :: (forall s. AD s ForwardDouble -> AD s ForwardDouble) -> Double -> Double -> [Double]
inverse f = Rank1.inverse (\x -> runAD (f (AD x)))
{-# INLINE inverse #-}
-- | Variant of 'inverse' that delegates to 'Rank1.inverseNoEq'.
--
-- NOTE(review): presumably the stream is not cut off when successive
-- iterates compare equal. Confirm against the Rank1 module.
--
-- The branded function is unwrapped ('AD' in, 'runAD' out) before
-- delegating.
inverseNoEq :: (forall s. AD s ForwardDouble -> AD s ForwardDouble) -> Double -> Double -> [Double]
inverseNoEq f = Rank1.inverseNoEq (\x -> runAD (f (AD x)))
{-# INLINE inverseNoEq #-}
-- | Search for a fixed point of a scalar function from the given start.
--
-- Returns the stream produced by 'Rank1.fixedPoint'. NOTE(review):
-- presumably the stream approximates an @x@ with @f x == x@.
--
-- The branded function is unwrapped ('AD' in, 'runAD' out) before
-- delegating.
fixedPoint :: (forall s. AD s ForwardDouble -> AD s ForwardDouble) -> Double -> [Double]
fixedPoint f = Rank1.fixedPoint (\x -> runAD (f (AD x)))
{-# INLINE fixedPoint #-}
-- | Variant of 'fixedPoint' that delegates to 'Rank1.fixedPointNoEq'.
--
-- NOTE(review): presumably the stream is not cut off when successive
-- iterates compare equal. Confirm against the Rank1 module.
--
-- The branded function is unwrapped ('AD' in, 'runAD' out) before
-- delegating.
fixedPointNoEq :: (forall s. AD s ForwardDouble -> AD s ForwardDouble) -> Double -> [Double]
fixedPointNoEq f = Rank1.fixedPointNoEq (\x -> runAD (f (AD x)))
{-# INLINE fixedPointNoEq #-}
-- | Search for an extremum of a scalar function from the given start.
--
-- Returns the stream produced by 'Rank1.extremum'. The function is
-- evaluated at @On (Forward ForwardDouble)@, a nested forward mode.
-- NOTE(review): presumably the nesting supplies the first and second
-- derivatives that a Newton step on @f'@ needs.
--
-- The branded function is unwrapped ('AD' in, 'runAD' out) before
-- delegating.
extremum :: (forall s. AD s (On (Forward ForwardDouble)) -> AD s (On (Forward ForwardDouble))) -> Double -> [Double]
extremum f = Rank1.extremum (\x -> runAD (f (AD x)))
{-# INLINE extremum #-}
-- | Variant of 'extremum' that delegates to 'Rank1.extremumNoEq'.
--
-- NOTE(review): presumably the stream is not cut off when successive
-- iterates compare equal. Confirm against the Rank1 module.
--
-- The branded function is unwrapped ('AD' in, 'runAD' out) before
-- delegating.
extremumNoEq :: (forall s. AD s (On (Forward ForwardDouble)) -> AD s (On (Forward ForwardDouble))) -> Double -> [Double]
extremumNoEq f = Rank1.extremumNoEq (\x -> runAD (f (AD x)))
{-# INLINE extremumNoEq #-}
-- | Conjugate-gradient descent on a multivariate function.
--
-- Runs 'conjugateGradientAscent' on the negated objective. Climbing @-f@
-- is the same as descending @f@, so the returned stream of points is the
-- descent path for @f@ itself. No sign flip is needed on the output
-- because only points, not objective values, are returned.
conjugateGradientDescent
  :: Traversable f
  => (forall s. Chosen s => f (Or s (On (Forward ForwardDouble)) (Kahn Double)) -> Or s (On (Forward ForwardDouble)) (Kahn Double))
  -> f Double -> [f Double]
conjugateGradientDescent f = conjugateGradientAscent (\xs -> negate (f xs))
{-# INLINE conjugateGradientDescent #-}
-- | Evaluate an 'Or'-typed function on its left branch (tag 'F').
--
-- Every input is injected with 'L', and the result is projected back out
-- with 'runL'.
lfu :: Functor f => (f (Or F a b) -> Or F a b) -> f a -> a
lfu f xs = runL (f (fmap L xs))
-- | Evaluate an 'Or'-typed function on its right branch (tag 'T').
--
-- Every input is injected with 'R', and the result is projected back out
-- with 'runR'.
rfu :: Functor f => (f (Or T a b) -> Or T a b) -> f b -> b
rfu f xs = runR (f (fmap R xs))
-- | Nonlinear conjugate-gradient ascent on a multivariate function.
--
-- Returns the lazy stream of iterates, beginning with the starting point.
-- The objective is polymorphic in the 'Chosen' tag, so it can be run in
-- two modes:
--
--   * on the right branch ('rfu') with reverse mode ('Kahn.grad'), to get
--     gradients;
--   * on the left branch ('lfu') with nested forward mode, to drive the
--     one-dimensional line search ('Rank1.extremum').
--
-- Each step:
--
--   1. chooses a step length by taking the last of (at most) the first 20
--      elements of 'Rank1.extremum' applied to @t -> f (x + t*d)@, starting
--      from @t = 0@;
--   2. moves to @x + t*d@ and recomputes the gradient @g@ there;
--   3. updates the direction to @g + beta*d@, with the Fletcher-Reeves
--      ratio @beta = (g.g) / (previous g.g)@.
--
-- The stream is cut off just before the first iterate containing a NaN
-- component (the only value with @v /= v@). NOTE(review): presumably this
-- is what ends the stream once the gradient vanishes, because the line
-- search along a zero direction then yields NaN.
--
-- NOTE(review): 'last' assumes 'Rank1.extremum' always returns a non-empty
-- list.
conjugateGradientAscent
  :: Traversable f
  => (forall s. Chosen s => f (Or s (On (Forward ForwardDouble)) (Kahn Double)) -> Or s (On (Forward ForwardDouble)) (Kahn Double))
  -> f Double -> [f Double]
conjugateGradientAscent f x0 =
    takeWhile (all (\v -> v == v)) (loop x0 g0 (inner g0 g0))
  where
    -- Euclidean inner product of two same-shaped containers.
    inner u v = sum (zipWithT (*) u v)

    -- Reverse-mode gradient of the objective.
    gradient = Kahn.grad (rfu f)

    -- The first search direction is the initial gradient itself.
    g0 = gradient x0

    -- State: current point, current direction, squared norm of the
    -- gradient that produced the current direction.
    loop x dir gg = x : loop x' dir' gg'
      where
        stepLen = last (take 20 (Rank1.extremum along 0))
        along t = lfu f (zipWithT (\p q -> auto p + t * auto q) x dir)
        x' = zipWithT (\p q -> p + stepLen * q) x dir
        g' = gradient x'
        gg' = inner g' g'
        beta = gg' / gg
        dir' = zipWithT (\r q -> r + beta * q) g' dir
{-# INLINE conjugateGradientAscent #-}