{-# LANGUAGE BangPatterns #-}
module Bio.Util.AD2
    ( AD2(..)
    , paramVector2
    ) where

import qualified Data.Vector.Unboxed as U

-- | Simple forward-mode AD to get a scalar valued function
-- with gradient and Hessian.
--
-- 'C2' is a constant (all derivatives are implicitly zero).  'D2' carries
-- the value, the gradient, and the Hessian.  The Hessian is stored flat:
-- for a gradient of length @d@ it has @d*d@ entries, laid out as @d@
-- consecutive rows of length @d@ (see the 'Show' instance and
-- 'paramVector2').
data AD2 = C2 !Double
         | D2 !Double !(U.Vector Double) !(U.Vector Double)

-- | Shows the value, then the gradient as a list, then the Hessian as a
-- list of rows, separated by single spaces.  Constants show as a bare number.
instance Show AD2 where
    show (C2 x)     = show x
    show (D2 x y z) = show x ++ " " ++ show (U.toList y) ++ " " ++
                      show [ U.toList (U.slice i d z) | i <- [0, d .. d*d-1] ]
      where
        d = U.length y

-- | Equality looks at the value only; derivatives are ignored.
instance Eq AD2 where
    C2 x     == C2 y     = x == y
    C2 x     == D2 y _ _ = x == y
    D2 x _ _ == C2 y     = x == y
    D2 x _ _ == D2 y _ _ = x == y

-- | Ordering looks at the value only; derivatives are ignored.
instance Ord AD2 where
    C2 x     `compare` C2 y     = x `compare` y
    C2 x     `compare` D2 y _ _ = x `compare` y
    D2 x _ _ `compare` C2 y     = x `compare` y
    D2 x _ _ `compare` D2 y _ _ = x `compare` y

instance Num AD2 where
    {-# INLINE (+) #-}
    C2 x     + C2 y     = C2 (x+y)
    C2 x     + D2 y v h = D2 (x+y) v h
    D2 x u g + C2 y     = D2 (x+y) u g
    D2 x u g + D2 y v h = D2 (x+y) (U.zipWith (+) u v) (U.zipWith (+) g h)

    {-# INLINE (-) #-}
    C2 x     - C2 y     = C2 (x-y)
    C2 x     - D2 y v h = D2 (x-y) (U.map negate v) (U.map negate h)
    D2 x u g - C2 y     = D2 (x-y) u g
    D2 x u g - D2 y v h = D2 (x-y) (U.zipWith (-) u v) (U.zipWith (-) g h)

    {-# INLINE (*) #-}
    C2 x     * C2 y     = C2 (x*y)
    C2 x     * D2 y v h = D2 (x*y) (U.map (x*) v) (U.map (x*) h)
    D2 x u g * C2 y     = D2 (x*y) (U.map (y*) u) (U.map (y*) g)
    D2 x u g * D2 y v h = D2 (x*y) grad hess
      where
        -- Product rule: (xy)' = x y' + y x'
        grad = U.zipWith (+) (U.map (x*) v) (U.map (y*) u)
        -- (xy)'' = x y'' + y x'' + x' y'^T + y' x'^T
        hess = U.zipWith (+) (U.zipWith (+) (U.map (x*) h) (U.map (y*) g))
                             (U.zipWith (+) (cross u v) (cross v u))

    {-# INLINE negate #-}
    negate (C2 x)     = C2 (negate x)
    negate (D2 x u g) = D2 (negate x) (U.map negate u) (U.map negate g)

    {-# INLINE fromInteger #-}
    fromInteger = C2 . fromInteger

    -- For a negative value everything is negated; otherwise (including a
    -- value of exactly zero) the number is returned unchanged.
    {-# INLINE abs #-}
    abs (C2 x) = C2 (abs x)
    abs (D2 x u g)
        | x < 0     = D2 (negate x) (U.map negate u) (U.map negate g)
        | otherwise = D2 x u g

    -- The result is always a constant: derivatives are dropped.
    {-# INLINE signum #-}
    signum (C2 x)     = C2 (signum x)
    signum (D2 x _ _) = C2 (signum x)

instance Fractional AD2 where
    {-# INLINE (/) #-}
    C2 x     / C2 y = C2 (x/y)
    -- Dividing by a constant only scales value and derivatives.
    D2 x u g / C2 y = D2 (x*z) (U.map (z*) u) (U.map (z*) g)
      where
        z = recip y
    x        / y    = x * recip y

    {-# INLINE recip #-}
    recip = liftF recip (\x -> - recip (sqr x)) (\x -> 2 * recip (cube x))

    {-# INLINE fromRational #-}
    fromRational = C2 . fromRational

-- | Every method is 'liftF' applied to the function, its first derivative,
-- and its second derivative (all as plain 'Double' functions).
instance Floating AD2 where
    {-# INLINE pi #-}
    pi = C2 pi

    {-# INLINE exp #-}
    exp = liftF exp exp exp

    -- d/dx sqrt x = 1 / (2 sqrt x);  d²/dx² sqrt x = -1 / (4 x^(3/2)).
    {-# INLINE sqrt #-}
    sqrt = liftF sqrt (\x -> recip $ 2 * sqrt x)
                      (\x -> - recip (4 * sqrt (cube x)))

    {-# INLINE log #-}
    log = liftF log recip (\x -> - recip (sqr x))

    sin   = liftF sin cos (negate . sin)
    cos   = liftF cos (negate . sin) (negate . cos)
    sinh  = liftF sinh cosh sinh
    cosh  = liftF cosh sinh cosh

    tan   = liftF tan   (\x -> recip (sqr (cos x)))
                        (\x -> 2 * tan x / sqr (cos x))
    tanh  = liftF tanh  (\x -> recip (sqr (cosh x)))
                        (\x -> -2 * tanh x / sqr (cosh x))

    asin  = liftF asin  (\x -> recip (sqrt (1 - sqr x)))
                        (\x -> x / sqrt (cube (1 - sqr x)))
    acos  = liftF acos  (\x -> - recip (sqrt (1 - sqr x)))
                        (\x -> -x / sqrt (cube (1 - sqr x)))

    asinh = liftF asinh (\x -> recip (sqrt (sqr x + 1)))
                        (\x -> -x / sqrt (cube (sqr x + 1)))
    -- d/dx acosh x = 1 / sqrt (x² - 1);  d²/dx² acosh x = -x / (x² - 1)^(3/2).
    -- Unlike acos, the first derivative of acosh is positive.
    acosh = liftF acosh (\x -> recip (sqrt (sqr x - 1)))
                        (\x -> -x / sqrt (cube (sqr x - 1)))

    atan  = liftF atan  (\x -> recip (1 + sqr x))
                        (\x -> -2 * x / sqr (1 + sqr x))
    atanh = liftF atanh (\x -> recip (1 - sqr x))
                        (\x -> 2 * x / sqr (1 - sqr x))

{-# INLINE sqr #-}
-- | @sqr x = x * x@
sqr :: Double -> Double
sqr x = x * x

{-# INLINE cube #-}
-- | @cube x = x * x * x@
cube :: Double -> Double
cube x = x * x * x

{-# INLINE liftF #-}
-- | Lifts a univariate function to 'AD2' by the chain rule.
--
-- Arguments are the function @f@, its first derivative @f'@ and its second
-- derivative @f''@.  For an argument with value @x@, gradient @v@ and
-- Hessian @g@ the result has
--
-- * value    @f x@
-- * gradient @f' x * v@
-- * Hessian  @f' x * g + f'' x * (v v^T)@
--
-- Constants stay constants; the derivative functions are not used for them.
liftF :: (Double -> Double) -> (Double -> Double) -> (Double -> Double)
      -> AD2 -> AD2
liftF f _  _   (C2 x)     = C2 (f x)
liftF f f' f'' (D2 x v g) = D2 (f x) (U.map (* f' x) v) hess
  where
    hess = U.zipWith (+) (U.map (* f'  x) g)
                         (U.map (* f'' x) (cross v v))

{-# INLINE cross #-}
-- | Flattened outer product.  The result has @length u * length v@ entries;
-- the entry at index @i * length u + j@ is @v!i * u!j@.
cross :: U.Vector Double -> U.Vector Double -> U.Vector Double
cross u v = U.concatMap (\dy -> U.map (dy*) u) v

{-# INLINE paramVector2 #-}
-- | Turns a list of values into independent variables.  The @i@-th result
-- has the @i@-th input as its value, a gradient of length @length xs@ that
-- is 1 at position @i@ and 0 elsewhere, and an all-zero Hessian of
-- @length xs * length xs@ entries.
paramVector2 :: [Double] -> [AD2]
paramVector2 xs =
    [ D2 x (U.generate l (\j -> if i == j then 1 else 0)) nil
    | (i,x) <- zip [0..] xs ]
  where
    l   = length xs
    nil = U.replicate (l*l) 0