{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE ParallelListComp #-}
module Numeric.AD.Newton
(
findZero
, findZeroNoEq
, inverse
, inverseNoEq
, fixedPoint
, fixedPointNoEq
, extremum
, extremumNoEq
, gradientDescent, constrainedDescent, CC(..), eval
, gradientAscent
, conjugateGradientDescent
, conjugateGradientAscent
, stochasticGradientDescent
) where
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable (Foldable, all, sum)
#else
import Data.Foldable (all, sum)
#endif
import Data.Reflection (Reifies)
import Data.Traversable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Forward (Forward)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Or
import Numeric.AD.Internal.Reverse (Reverse, Tape)
import Numeric.AD.Internal.Type (AD(..))
import Numeric.AD.Mode
import Numeric.AD.Mode.Reverse as Reverse (gradWith, gradWith', grad')
import Numeric.AD.Rank1.Kahn as Kahn (Kahn, grad)
import qualified Numeric.AD.Rank1.Newton as Rank1
import Prelude hiding (all, mapM, sum)
-- | Compute the stream of Newton iterates converging toward a zero of a
-- scalar function, starting from the given seed.  Stops early when two
-- successive iterates are equal (hence the 'Eq' constraint); delegates to
-- 'Rank1.findZero' after unwrapping the 'AD' newtype.
findZero :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
findZero f = Rank1.findZero (\x -> runAD (f (AD x)))
{-# INLINE findZero #-}
-- | Like 'findZero', but produces the (potentially infinite) stream of
-- Newton iterates without checking consecutive values for equality, so no
-- 'Eq' constraint is required.  Delegates to 'Rank1.findZeroNoEq'.
findZeroNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
findZeroNoEq f = Rank1.findZeroNoEq (\x -> runAD (f (AD x)))
{-# INLINE findZeroNoEq #-}
-- | Produce the stream of Newton iterates inverting a scalar function:
-- given a seed and a target value, approach an @x@ with @f x@ equal to the
-- target.  Delegates to 'Rank1.inverse' after unwrapping the 'AD' newtype.
inverse :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> a -> [a]
inverse f = Rank1.inverse (\x -> runAD (f (AD x)))
{-# INLINE inverse #-}
-- | Like 'inverse', but without the early-termination equality check on
-- consecutive iterates, so no 'Eq' constraint is needed.  Delegates to
-- 'Rank1.inverseNoEq'.
inverseNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> a -> [a]
inverseNoEq f = Rank1.inverseNoEq (\x -> runAD (f (AD x)))
{-# INLINE inverseNoEq #-}
-- | Produce the stream of Newton iterates converging toward a fixed point
-- of the function (an @x@ with @f x == x@), stopping when two successive
-- iterates coincide.  Delegates to 'Rank1.fixedPoint'.
fixedPoint :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
fixedPoint f = Rank1.fixedPoint (\x -> runAD (f (AD x)))
{-# INLINE fixedPoint #-}
-- | Like 'fixedPoint', but never compares consecutive iterates, so the
-- stream may be infinite and no 'Eq' constraint is required.  Delegates to
-- 'Rank1.fixedPointNoEq'.
fixedPointNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
fixedPointNoEq f = Rank1.fixedPointNoEq (\x -> runAD (f (AD x)))
{-# INLINE fixedPointNoEq #-}
-- | Produce the stream of Newton iterates converging toward an extremum of
-- a scalar function, using nested forward mode ('On' over doubled
-- 'Forward') for second derivatives.  Delegates to 'Rank1.extremum'.
extremum :: (Fractional a, Eq a) => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a)))) -> a -> [a]
extremum f = Rank1.extremum (\x -> runAD (f (AD x)))
{-# INLINE extremum #-}
-- | Like 'extremum', but without the consecutive-iterate equality check,
-- so the result may be an infinite stream and no 'Eq' constraint is
-- needed.  Delegates to 'Rank1.extremumNoEq'.
extremumNoEq :: Fractional a => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a)))) -> a -> [a]
extremumNoEq f = Rank1.extremumNoEq (\x -> runAD (f (AD x)))
{-# INLINE extremumNoEq #-}
-- | Gradient descent with a simple adaptive step size, using reverse-mode
-- AD for the gradient.  Returns the stream of accepted iterates.  The step
-- size starts at 0.1, halves whenever a step would increase the function
-- value, and doubles after ten consecutive successful steps.  The stream
-- ends when the step size underflows to zero or the gradient vanishes.
gradientDescent :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> [f a]
gradientDescent f x0 = loop x0 y0 pairs0 0.1 (0 :: Int)
  where
    -- Value and (point, gradient) pairs at the starting position.
    (y0, pairs0) = Reverse.gradWith' (,) f x0
    loop x y pairs !rate !streak
      -- Step size underflowed: no further progress is possible.
      | rate == 0 = []
      -- Overshot: retry from the same point with half the step size.
      | y' > y = loop x y pairs (rate / 2) 0
      -- Gradient is exactly zero: we are at a stationary point.
      | all ((== 0) . snd) pairs = []
      | otherwise = x' :
          if streak == 10
            -- Ten improvements in a row: be bolder and double the step.
            then loop x' y' pairs' (rate * 2) 0
            else loop x' y' pairs' rate (streak + 1)
      where
        -- Candidate next point: move each coordinate against its gradient.
        x' = fmap (\(xi, gi) -> xi - rate * gi) pairs
        (y', pairs') = Reverse.gradWith' (,) f x'
{-# INLINE gradientDescent #-}
-- | An environment @f a@ augmented with a single slack value, used by
-- 'constrainedDescent' to turn a feasibility search into an optimization
-- over the slack.  'sValue' is the slack; 'origEnv' recovers the original
-- environment.
data SEnv (f :: * -> *) a = SEnv { sValue :: a, origEnv :: f a }
  deriving (Functor, Foldable, Traversable)
-- | A convex constraint (or objective): a rank-2 reverse-mode-differentiable
-- function over an environment @f@, wrapped so that lists of heterogeneously
-- quantified constraints can be passed around first-class.
data CC f a where
  CC :: forall f a. (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> CC f a
-- | Minimize an objective subject to inequality constraints (each @CC c@
-- is satisfied where @c env <= 0@) via a two-phase barrier-style method:
-- phase 1 finds a feasible point by descending on a slack variable, phase 2
-- optimizes the real objective from that point.  Returns (value, point)
-- pairs along the way.
constrainedDescent :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
                   => (forall s. Reifies s Tape => f (Reverse s a)
                   -> Reverse s a)
                   -> [CC f a]
                   -> f a
                   -> [(a,f a)]
-- With no constraints this degenerates to plain gradient descent.
constrainedDescent objF [] env =
  map (\x -> (eval objF x, x)) (gradientDescent objF env)
constrainedDescent objF cs env =
      -- s0 strictly exceeds every constraint value at env, so the
      -- slack-augmented point (s0, env) satisfies every relaxed constraint.
  let s0 = 1 + maximum [eval c env | CC c <- cs]
      -- Relaxed constraints over the slack-augmented environment:
      -- c rest - sVal <= 0, which holds by construction at sVal = s0.
      cs' = [CC (\(SEnv sVal rest) -> c rest - sVal) | CC c <- cs]
      envS = SEnv s0 env
      -- Phase 1: minimize the slack itself until it becomes non-positive,
      -- i.e. until the original constraints are all satisfied.
      cc = constrainedConvex' (CC sValue) cs' envS ((<=0) . sValue)
  -- Drop iterates with positive slack, within a bounded iteration budget.
  in case dropWhile ((0 <) . fst) (take (2^(20::Int)) cc) of
       [] -> [] -- no feasible point found within the budget
       (_,envFeasible) : _ ->
         -- Phase 2: optimize the real objective from the feasible point.
         constrainedConvex' (CC objF) cs (origEnv envFeasible) (const True)
{-# INLINE constrainedDescent #-}
-- | Evaluate a reverse-mode-differentiable function at a point, keeping
-- only the primal value and discarding the gradient that 'grad'' computes
-- alongside it.
eval :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> a
eval f = fst . grad' f
{-# INLINE eval #-}
-- | Core barrier-method driver: optimize @objF@ penalized by log-barriers
-- for the constraints @cs@, for a doubling sequence of barrier sharpness
-- parameters @t@, restarting each stage from the last point of the
-- previous one.  Emits (value, point) pairs; iteration stops being
-- consumed once @term@ accepts a point (the caller sees everything from
-- the first accepted point onward).
constrainedConvex' :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
                   => CC f a
                   -> [CC f a]
                   -> f a
                   -> (f a -> Bool)
                   -> [(a,f a)]
constrainedConvex' objF cs env term =
      -- One penalized objective per barrier parameter t.
  let os = map (mkOpt objF cs) tValues
      -- Stage i+1 starts from the last point of (truncated) stage i.
      -- The seed's fst is 'undefined' but is never forced: only the snd of
      -- the last entry is read, and the seed list is dropped below.
      envs = [(undefined,env)] :
             [gD (snd $ last e) o  -- ParallelListComp: o and e in lockstep
             | o <- os
             | e <- limEnvs
             ]
      -- Truncate every stage to 20 steps except the final one (id).
      limEnvs = zipWith id nrSteps envs
  in dropWhile (not . term . snd) (concat $ drop 1 limEnvs)
  where
    -- Barrier parameters 2, 4, 8, ... (64 doublings).
    tValues = map realToFrac $ take 64 $ iterate (*2) (2 :: a)
    nrSteps = [take 20 | _ <- [1..length tValues]] ++ [id]
    -- Gradient descent on a wrapped objective, pairing each point with
    -- its objective value (the starting point included).
    gD e (CC f) = (eval f e, e) :
                  map (\x -> (eval f x, x)) (gradientDescent f e)
{-# INLINE constrainedConvex' #-}
-- | Build the barrier-penalized objective for sharpness parameter @t@: the
-- original objective plus the sum of 'iHat' log-barrier terms, one per
-- constraint.
mkOpt :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
      => CC f a -> [CC f a]
      -> a -> CC f a
mkOpt (CC obj) constraints t =
  CC (\env -> obj env + sum [iHat t c env | CC c <- constraints])
{-# INLINE mkOpt #-}
-- | Log-barrier penalty for one constraint @c@ at sharpness @t@: infinite
-- outside the feasible region (@c e >= 0@, or NaN), and
-- @(-1/t) * log (-(c e))@ inside it, which grows without bound as the
-- boundary @c e == 0@ is approached.
--
-- Fix: the original evaluated @c e@ twice — once bound to @r@ for the
-- feasibility test and again inside the @log@ — running the constraint's
-- reverse-mode evaluation redundantly.  Reuse @r@ instead.
iHat :: forall a f. (Traversable f, RealFloat a, Floating a, Ord a)
     => a
     -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
     -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
iHat t c e =
  let r = c e
  in if r >= 0 || isNaN r
       then 1 / 0                          -- infeasible: infinite penalty
       else (-1 / auto t) * log (negate r) -- -(c e) > 0 here, so log is defined
{-# INLINE iHat #-}
-- | Stochastic gradient descent with a fixed learning rate of 0.001:
-- each step follows the gradient of the per-sample error for the next
-- training sample, cycling through the data forever.  Returns the stream
-- of iterates.  NOTE(review): assumes the sample list is non-empty
-- ('head'/'cycle' diverge otherwise) — confirm at call sites.
stochasticGradientDescent :: (Traversable f, Fractional a, Ord a)
                          => (forall s. Reifies s Tape => f (Scalar a) -> f (Reverse s a) -> Reverse s a)
                          -> [f (Scalar a)]
                          -> f a
                          -> [f a]
stochasticGradientDescent errorSingle d0 x0 = loop pairs0 0.001 samples
  where
    -- Endless sample stream, starting after the sample consumed below.
    samples = tail (cycle d0)
    -- (point, gradient) pairs for the first sample at the starting point.
    pairs0 = Reverse.gradWith (,) (errorSingle (head d0)) x0
    loop pairs !rate rest
      | rate == 0 = []
      | otherwise = x' : loop pairs' rate (tail rest)
      where
        -- Move each coordinate against its gradient.
        x' = fmap (\(xi, gi) -> xi - rate * gi) pairs
        -- Gradient of the next sample's error at the new point.
        (_, pairs') = Reverse.gradWith' (,) (errorSingle (head rest)) x'
{-# INLINE stochasticGradientDescent #-}
-- | Gradient ascent: maximize the function by running 'gradientDescent'
-- on its negation.
gradientAscent :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> [f a]
gradientAscent f = gradientDescent (\x -> negate (f x))
{-# INLINE gradientAscent #-}
-- | Conjugate gradient descent: minimize the function by running
-- 'conjugateGradientAscent' on its negation.  The 'Chosen' machinery lets
-- the same function be evaluated in both the nested-forward mode (for the
-- line search) and Kahn mode (for gradients).
conjugateGradientDescent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
conjugateGradientDescent f = conjugateGradientAscent (\x -> negate (f x))
{-# INLINE conjugateGradientDescent #-}
-- | Run an 'Or'-polymorphic function on its left ('F') branch: inject the
-- input with 'L', apply, and project the result back out with 'runL'.
lfu :: Functor f => (f (Or F a b) -> Or F a b) -> f a -> a
lfu f xs = runL (f (L <$> xs))
-- | Run an 'Or'-polymorphic function on its right ('T') branch: inject the
-- input with 'R', apply, and project the result back out with 'runR'.
rfu :: Functor f => (f (Or T a b) -> Or T a b) -> f b -> b
rfu f xs = runR (f (R <$> xs))
-- | Nonlinear conjugate gradient ascent.  Gradients come from Kahn mode
-- (via 'rfu'); the step length along each search direction comes from a
-- one-dimensional Newton line search in nested forward mode (via 'lfu').
-- The direction update uses the ratio of successive gradient
-- inner-products (beta = <r_{i+1},r_{i+1}> / <r_i,r_i>, the
-- Fletcher–Reeves formula).  Returns the stream of iterates.
conjugateGradientAscent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
-- Cut the stream at the first NaN coordinate: (a == a) is False for NaN.
conjugateGradientAscent f x0 = takeWhile (all (\a -> a == a)) (go x0 d0 d0 delta0)
  where
    dot x y = sum $ zipWithT (*) x y
    -- Initial search direction is the gradient at the start point.
    d0 = Kahn.grad (rfu f) x0
    delta0 = dot d0 d0
    -- _ri (the previous residual/gradient) is carried but unused here.
    go xi _ri di deltai = xi : go xi1 ri1 di1 deltai1
      where
        -- Line search: at most 20 Newton steps toward an extremum of
        -- a |-> f (xi + a * di), starting from a = 0.
        ai = last $ take 20 $ Rank1.extremum (\a -> lfu f $ zipWithT (\x d -> auto x + a * auto d) xi di) 0
        xi1 = zipWithT (\x d -> x + ai*d) xi di
        ri1 = Kahn.grad (rfu f) xi1
        deltai1 = dot ri1 ri1
        -- Fletcher–Reeves coefficient.
        bi1 = deltai1 / deltai
        -- New direction: gradient plus beta times the old direction.
        di1 = zipWithT (\r d -> r + bi1 * d) ri1 di
{-# INLINE conjugateGradientAscent #-}