{-# LANGUAGE BangPatterns #-} module Bio.Util.AD ( AD(..), paramVector, minimize , module Numeric.Optimization.Algorithms.HagerZhang05 , debugParameters, quietParameters ) where import Numeric.Optimization.Algorithms.HagerZhang05 import qualified Data.Vector.Unboxed as U import qualified Data.Vector.Storable as V -- | Simple forward-mode AD to get a scalar valued function with gradient. data AD = C !Double | D !Double !(U.Vector Double) deriving Show instance Eq AD where C x == C y = x == y C x == D y _ = x == y D x _ == C y = x == y D x _ == D y _ = x == y instance Ord AD where C x `compare` C y = x `compare` y C x `compare` D y _ = x `compare` y D x _ `compare` C y = x `compare` y D x _ `compare` D y _ = x `compare` y instance Num AD where {-# INLINE (+) #-} C x + C y = C (x+y) C x + D y v = D (x+y) v D x u + C y = D (x+y) u D x u + D y v = D (x+y) (U.zipWith (+) u v) {-# INLINE (-) #-} C x - C y = C (x-y) C x - D y v = D (x-y) (U.map negate v) D x u - C y = D (x-y) u D x u - D y v = D (x-y) (U.zipWith (-) u v) {-# INLINE (*) #-} C x * C y = C (x*y) C x * D y v = D (x*y) (U.map (x*) v) D x u * C y = D (x*y) (U.map (y*) u) D x u * D y v = D (x*y) (U.zipWith (+) (U.map (x*) v) (U.map (y*) u)) {-# INLINE negate #-} negate (C x) = C (negate x) negate (D x u) = D (negate x) (U.map negate u) {-# INLINE fromInteger #-} fromInteger = C . fromInteger {-# INLINE abs #-} abs (C x) = C (abs x) abs (D x u) | x < 0 = D (negate x) (U.map negate u) | otherwise = D x u {-# INLINE signum #-} signum (C x) = C (signum x) signum (D x _) = C (signum x) instance Fractional AD where {-# INLINE (/) #-} C x / C y = C (x/y) D x u / C y = D (x*z) (U.map (z*) u) where z = recip y C x / D y v = D (x/y) (U.map (w*) v) where w = negate $ x * z * z ; z = recip y D x u / D y v = D (x/y) (U.zipWith (-) (U.map (z*) u) (U.map (w*) v)) where z = recip y ; w = x * z * z {-# INLINE recip #-} recip = liftF recip (\x -> - recip (x*x)) {-# INLINE fromRational #-} fromRational = C . fromRational instance Floating AD where {-# INLINE pi #-} pi = C pi {-# INLINE exp #-} exp = liftF exp exp {-# INLINE sqrt #-} sqrt = liftF sqrt $ \x -> recip (2 * sqrt x) {-# INLINE log #-} log = liftF log recip sin = liftF sin cos cos = liftF cos (negate . sin) sinh = liftF sinh cosh cosh = liftF cosh sinh tan = liftF tan $ \x -> recip (cos x * cos x) tanh = liftF tanh $ \x -> recip (cosh x * cosh x) asin = liftF asin $ \x -> recip (sqrt (1 - x * x)) acos = liftF acos $ \x -> - recip (sqrt (1 - x * x)) atan = liftF atan $ \x -> recip (1 + x * x) asinh = liftF asinh $ \x -> recip (sqrt (x * x + 1)) acosh = liftF acosh $ \x -> - recip (sqrt (x * x - 1)) atanh = liftF atanh $ \x -> recip (1 - x * x) {-# INLINE liftF #-} liftF :: (Double -> Double) -> (Double -> Double) -> AD -> AD liftF f _ (C x) = C (f x) liftF f g (D x u) = D (f x) (U.map (* g x) u) {-# INLINE paramVector #-} paramVector :: [Double] -> [AD] paramVector xs = [ D x (U.generate l (\j -> if i == j then 1 else 0)) | (i,x) <- zip [0..] xs ] where l = length xs {-# INLINE minimize #-} minimize :: Parameters -> Double -> ([AD] -> AD) -> U.Vector Double -> IO (V.Vector Double, Result, Statistics) minimize params eps func v0 = optimize params eps v0 (VFunction $ fst . combofn) (VGradient $ snd . combofn) (Just . VCombined $ combofn) where combofn parms = case func $ paramVector $ U.toList parms of D x g -> ( x, g ) C x -> ( x, U.replicate (U.length parms) 0 ) quietParameters :: Parameters quietParameters = defaultParameters { printFinal = False, verbose = Quiet, maxItersFac = 123 } debugParameters :: Parameters debugParameters = defaultParameters { verbose = Verbose }