module Numeric.AD.Newton
(
findZero
, findZeroNoEq
, inverse
, inverseNoEq
, fixedPoint
, fixedPointNoEq
, extremum
, extremumNoEq
, gradientDescent, constrainedDescent, CC(..), eval
, gradientAscent
, conjugateGradientDescent
, conjugateGradientAscent
, stochasticGradientDescent
) where
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable (Foldable, all, sum)
#else
import Data.Foldable (all, sum)
#endif
import Data.Reflection (Reifies)
import Data.Traversable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Forward (Forward)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Or
import Numeric.AD.Internal.Reverse (Reverse, Tape)
import Numeric.AD.Internal.Type (AD(..))
import Numeric.AD.Mode
import Numeric.AD.Mode.Reverse as Reverse (gradWith, gradWith', grad')
import Numeric.AD.Rank1.Kahn as Kahn (Kahn, grad)
import qualified Numeric.AD.Rank1.Newton as Rank1
import Prelude hiding (all, mapM, sum)
-- | Find a zero of a scalar function using Newton's method, streaming
-- successively better approximations.  Rank-2 wrapper over
-- 'Rank1.findZero'.
findZero :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
findZero f = Rank1.findZero (\x -> runAD (f (AD x)))
-- | Rank-2 wrapper over 'Rank1.findZeroNoEq'; unlike 'findZero' it
-- does not require an 'Eq' instance on @a@.
findZeroNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
findZeroNoEq f = Rank1.findZeroNoEq (\x -> runAD (f (AD x)))
-- | Invert a scalar function near a starting guess using Newton's
-- method, streaming successively better approximations.  Rank-2
-- wrapper over 'Rank1.inverse'.
inverse :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> a -> [a]
inverse f = Rank1.inverse (\x -> runAD (f (AD x)))
-- | Rank-2 wrapper over 'Rank1.inverseNoEq'; unlike 'inverse' it does
-- not require an 'Eq' instance on @a@.
inverseNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> a -> [a]
inverseNoEq f = Rank1.inverseNoEq (\x -> runAD (f (AD x)))
-- | Find a fixed point of a scalar function using Newton's method,
-- streaming successively better approximations.  Rank-2 wrapper over
-- 'Rank1.fixedPoint'.
fixedPoint :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
fixedPoint f = Rank1.fixedPoint (\x -> runAD (f (AD x)))
-- | Rank-2 wrapper over 'Rank1.fixedPointNoEq'; unlike 'fixedPoint' it
-- does not require an 'Eq' instance on @a@.
fixedPointNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
fixedPointNoEq f = Rank1.fixedPointNoEq (\x -> runAD (f (AD x)))
-- | Find an extremum of a scalar function using Newton's method
-- (note the doubly-nested 'Forward' mode, which supplies the second
-- derivative).  Rank-2 wrapper over 'Rank1.extremum'.
extremum :: (Fractional a, Eq a) => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a)))) -> a -> [a]
extremum f = Rank1.extremum (\x -> runAD (f (AD x)))
-- | Rank-2 wrapper over 'Rank1.extremumNoEq'; unlike 'extremum' it
-- does not require an 'Eq' instance on @a@.
extremumNoEq :: Fractional a => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a)))) -> a -> [a]
extremumNoEq f = Rank1.extremumNoEq (\x -> runAD (f (AD x)))
-- | Perform gradient descent, using reverse-mode automatic
-- differentiation to compute the gradient at each step.
--
-- The step size @eta@ adapts: when a step fails to decrease the
-- function value it is halved and the step is retried; after 10
-- consecutive successful steps it is doubled.  The stream of iterates
-- ends when @eta@ underflows to 0 or the gradient is exactly zero.
gradientDescent :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    (fx0, xgx0) = Reverse.gradWith' (,) f x0
    go x fx xgx !eta !i
      | eta == 0     = []                        -- step size underflowed; stop
      | fx1 > fx     = go x fx xgx (eta/2) 0     -- overshot: halve step, retry
      | zeroGrad xgx = []                        -- exact stationary point
      | otherwise    = x1 : if i == 10
                            then go x1 fx1 xgx1 (eta*2) 0  -- 10 good steps: speed up
                            else go x1 fx1 xgx1 eta (i+1)
      where
        zeroGrad = all (\(_,g) -> g == 0)
        -- Descent step: move each coordinate *against* its gradient.
        x1 = fmap (\(xi,gxi) -> xi - eta * gxi) xgx
        (fx1, xgx1) = Reverse.gradWith' (,) f x1
-- | An environment augmented with a slack value, used internally by
-- 'constrainedDescent' during its phase-I feasibility search:
-- 'sValue' holds the slack @s@ and 'origEnv' the user's original
-- environment.
data SEnv (f :: * -> *) a = SEnv { sValue :: a, origEnv :: f a }
  deriving (Functor, Foldable, Traversable)
-- | A convex constraint (or objective): a tape-polymorphic
-- reverse-mode function.  When passed as a constraint to
-- 'constrainedDescent', @CC c@ is interpreted as the requirement
-- @c x <= 0@.
data CC f a where
  CC :: forall f a. (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> CC f a
-- | Minimize a convex objective subject to convex inequality
-- constraints (each @CC c@ is read as @c x <= 0@) using an
-- interior-point, log-barrier method.
--
-- With no constraints this is just 'gradientDescent', with each
-- iterate paired with its objective value.  Otherwise a phase-I
-- problem is solved first: a slack variable @s@, started strictly
-- above every constraint value, is minimized subject to the relaxed
-- constraints @c x - s <= 0@ until @s <= 0@, i.e. until a strictly
-- feasible point is found.  If none is found within @2^20@ iterates,
-- the result is @[]@.  Phase II then optimizes the real objective
-- starting from the feasible point.
constrainedDescent :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
  => (forall s. Reifies s Tape => f (Reverse s a)
                -> Reverse s a)
  -> [CC f a]
  -> f a
  -> [(a,f a)]
constrainedDescent objF [] env =
  map (\x -> (eval objF x, x)) (gradientDescent objF env)
constrainedDescent objF cs env =
  let -- Start the slack strictly above every constraint value, so the
      -- phase-I start point satisfies all relaxed constraints.
      s0 = 1 + maximum [eval c env | CC c <- cs]
      -- Relax each constraint @c x <= 0@ to @c x - s <= 0@.
      cs' = [CC (\(SEnv sVal rest) -> c rest - sVal) | CC c <- cs]
      envS = SEnv s0 env
      -- Phase I: minimize the slack itself until it becomes <= 0.
      cc = constrainedConvex' (CC sValue) cs' envS ((<=0) . sValue)
  in case dropWhile ((0 <) . fst) (take (2^(20::Int)) cc) of
       [] -> []
       (_,envFeasible) : _ ->
         -- Phase II: optimize the real objective from the feasible point.
         constrainedConvex' (CC objF) cs (origEnv envFeasible) (const True)
-- | Evaluate a reverse-mode objective at a point, discarding the
-- gradient that 'grad'' computes alongside the value.
eval :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> a
eval f e = let (v, _g) = grad' f e in v
-- | Core interior-point loop shared by 'constrainedDescent' and its
-- phase-I feasibility problem.
--
-- For an increasing sequence of barrier parameters @t@ (doubling from
-- 2, 64 stages) the objective plus the log-barriers of all constraints
-- ('mkOpt') is minimized by 'gradientDescent', warm-starting each
-- stage from the final iterate of the previous (truncated) stage.
-- All but the last stage are cut to 20 iterates; the last streams
-- indefinitely.  Iterates (paired with their stage-objective value)
-- are emitted starting from the first one satisfying @term@.
constrainedConvex' :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
  => CC f a
  -> [CC f a]
  -> f a
  -> (f a -> Bool)
  -> [(a,f a)]
constrainedConvex' objF cs env term =
  let os = map (mkOpt objF cs) tValues
      -- NOTE: envs / limEnvs are mutually recursive (lazy knot):
      -- stage k+1 reads the last element of *truncated* stage k.
      -- The seed pair's fst is 'undefined' but is never forced: the
      -- seed stage is discarded by 'drop 1' below, and only its snd
      -- (the start environment) is read via @snd $ last e@.
      envs = [(undefined,env)] :
             [gD (snd $ last e) o
             | o <- os
             | e <- limEnvs   -- parallel list comprehension: stages zip with their predecessors
             ]
      limEnvs = zipWith id nrSteps envs
  in dropWhile (not . term . snd) (concat $ drop 1 limEnvs)
  where
    -- Barrier parameters: 2, 4, 8, ... (64 of them).
    tValues = map realToFrac $ take 64 $ iterate (*2) (2 :: a)
    -- Truncate each stage to 20 steps, except the final one ('id').
    nrSteps = [take 20 | _ <- [1..length tValues]] ++ [id]
    -- One gradient-descent stage, with the start point included.
    gD e (CC f) = (eval f e, e) :
                  map (\x -> (eval f x, x)) (gradientDescent f e)
-- | Combine an objective with a set of constraints at barrier
-- parameter @t@: the result is the objective plus one log-barrier
-- term ('iHat') per constraint.
mkOpt :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
  => CC f a -> [CC f a]
  -> a -> CC f a
mkOpt (CC o) xs t =
  CC $ \e -> o e + sum [iHat t c e | CC c <- xs]
-- | Interior-point log-barrier for a single constraint @c e <= 0@.
--
-- When the constraint is violated (or evaluates to NaN) the barrier
-- is @+inf@, so a line search rejects infeasible points outright.
-- Inside the feasible region the standard barrier
-- @-(1/t) * log (-(c e))@ is used: it grows to @+inf@ as @c e@
-- approaches 0 from below, and its weight shrinks as the barrier
-- parameter @t@ increases.  (Note @c e < 0@ here, so the inner
-- 'negate' is required for the logarithm to be defined at all.)
iHat :: forall a f. (Traversable f, RealFloat a, Floating a, Ord a)
  => a
  -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
  -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
iHat t c e =
  let r = c e
  in if r >= 0 || isNaN r
     then 1 / 0   -- infeasible: infinite penalty
     else negate (1 / auto t) * log (negate (c e))
-- | Perform stochastic gradient descent: each step moves against the
-- gradient of the error on a single observation, cycling through the
-- supplied observations indefinitely.  The learning rate is fixed at
-- 0.001.
--
-- NOTE(review): the observation list is consumed with 'head'/'tail',
-- so an empty @d0@ is an error — presumably callers guarantee at
-- least one observation; confirm at call sites.
stochasticGradientDescent :: (Traversable f, Fractional a, Ord a)
  => (forall s. Reifies s Tape => f (Scalar a) -> f (Reverse s a) -> Reverse s a)
  -> [f (Scalar a)]
  -> f a
  -> [f a]
stochasticGradientDescent errorSingle d0 x0 = go xgx0 0.001 dLeft
  where
    dLeft = tail $ cycle d0
    xgx0 = Reverse.gradWith (,) (errorSingle (head d0)) x0
    go xgx !eta d
      | eta == 0  = []
      | otherwise = x1 : go xgx1 eta (tail d)
      where
        -- Step *against* the gradient of the current sample's error.
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        (_, xgx1) = Reverse.gradWith' (,) (errorSingle (head d)) x1
-- | Perform gradient ascent, implemented by running 'gradientDescent'
-- on the negated objective.
gradientAscent :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> [f a]
gradientAscent f = gradientDescent (\x -> negate (f x))
-- | Perform a conjugate gradient descent, implemented by running
-- 'conjugateGradientAscent' on the negated objective.
conjugateGradientDescent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
conjugateGradientDescent f = conjugateGradientAscent (\x -> negate (f x))
-- | Specialize a choice-polymorphic function to its left ('F') branch.
lfu :: Functor f => (f (Or F a b) -> Or F a b) -> f a -> a
lfu f xs = runL (f (fmap L xs))
-- | Specialize a choice-polymorphic function to its right ('T') branch.
rfu :: Functor f => (f (Or T a b) -> Or T a b) -> f b -> b
rfu f xs = runR (f (fmap R xs))
-- | Perform a conjugate gradient ascent.  The gradient is computed in
-- Kahn mode (the 'T' branch of the 'Or'-polymorphic objective), while
-- the line search along each direction uses nested forward mode (the
-- 'F' branch) via 'Rank1.extremum'.
--
-- The stream of iterates is cut off as soon as any coordinate becomes
-- NaN (@a == a@ is 'False' exactly for NaN).
conjugateGradientAscent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
conjugateGradientAscent f x0 = takeWhile (all (\a -> a == a)) (go x0 d0 d0 delta0)
  where
    -- Inner product of two same-shaped structures.
    dot x y = sum $ zipWithT (*) x y
    -- Initial search direction: the gradient at the start point.
    d0 = Kahn.grad (rfu f) x0
    delta0 = dot d0 d0
    go xi _ri di deltai = xi : go xi1 ri1 di1 deltai1
      where
        -- Line search: at most 20 Newton steps toward an extremum of
        -- the objective restricted to the line through xi along di.
        ai = last $ take 20 $ Rank1.extremum (\a -> lfu f $ zipWithT (\x d -> auto x + a * auto d) xi di) 0
        xi1 = zipWithT (\x d -> x + ai*d) xi di
        ri1 = Kahn.grad (rfu f) xi1
        deltai1 = dot ri1 ri1
        -- Fletcher-Reeves-style update coefficient (new/old gradient
        -- norms squared).
        bi1 = deltai1 / deltai
        di1 = zipWithT (\r d -> r + bi1 * d) ri1 di