{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Numeric.AD.Mode.Reverse.Double
( ReverseDouble, auto
, grad
, grad'
, gradWith
, gradWith'
, jacobian
, jacobian'
, jacobianWith
, jacobianWith'
, hessian
, hessianF
, diff
, diff'
, diffF
, diffF'
) where
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
import Data.Traversable (Traversable)
#endif
import Data.Functor.Compose
import Data.Reflection (Reifies)
import Numeric.AD.Internal.On
import qualified Numeric.AD.Internal.Reverse as R
import qualified Numeric.AD.Mode.Reverse as M
import Numeric.AD.Internal.Reverse.Double
import Numeric.AD.Mode
-- | Gradient of a function from a 'Traversable' container of inputs to a
-- single 'ReverseDouble' output, using the 'Double'-specialised reverse-mode
-- tape of this module.
--
-- The output, and then its array of partials, are both forced to weak head
-- normal form before 'unbind' maps the partials back onto the shape of the
-- input container.
grad
  :: Traversable f
  => (forall s. Reifies s Tape => f (ReverseDouble s) -> ReverseDouble s)
  -> f Double
  -> f Double
grad f as = reifyTape (snd bds) $ \p ->
  case f vs of
    !out -> case partialArrayOf p bds out of
      !partials -> unbind vs partials
  where
    -- Inputs lifted into 'ReverseDouble' by 'bind'; the second component of
    -- @bds@ is what sizes the tape, and @bds@ is also given to
    -- 'partialArrayOf'.
    (vs, bds) = bind as
{-# INLINE grad #-}
-- | Like 'grad', but additionally returns the 'primal' value of the
-- function's output.
--
-- A single evaluation of @f vs@ is shared by both components of the pair.
-- The gradient component stays a lazy thunk; when it is demanded, it forces
-- the output and its partial array before calling 'unbind'.
grad'
  :: Traversable f
  => (forall s. Reifies s Tape => f (ReverseDouble s) -> ReverseDouble s)
  -> f Double
  -> (Double, f Double)
grad' f as = reifyTape (snd bds) $ \p ->
  case f vs of
    out ->
      let gradient = unbind vs $! partialArrayOf p bds $! out
      in  (primal out, gradient)
  where
    (vs, bds) = bind as
{-# INLINE grad' #-}
-- | Like 'grad', but rather than returning the bare partials, combines each
-- one with its input through @g@ via 'unbindWith'. The argument order is
-- presumably @g input partial@; confirm against 'unbindWith'.
--
-- The output and its partial array are forced to weak head normal form
-- before 'unbindWith' runs.
gradWith
  :: Traversable f
  => (Double -> Double -> b)
  -> (forall s. Reifies s Tape => f (ReverseDouble s) -> ReverseDouble s)
  -> f Double
  -> f b
gradWith g f as = reifyTape (snd bds) $ \p ->
  case f vs of
    !out -> case partialArrayOf p bds out of
      !partials -> unbindWith g vs partials
  where
    (vs, bds) = bind as
{-# INLINE gradWith #-}
-- | Combination of 'grad'' and 'gradWith'. It returns the 'primal' value of
-- the output, together with the partials combined with their inputs through
-- @g@ (via 'unbindWith').
--
-- The output is computed once and shared by both components. The second
-- component forces the output and its partial array only when demanded.
gradWith'
  :: Traversable f
  => (Double -> Double -> b)
  -> (forall s. Reifies s Tape => f (ReverseDouble s) -> ReverseDouble s)
  -> f Double
  -> (Double, f b)
gradWith' g f as = reifyTape (snd bds) $ \p ->
  case f vs of
    out ->
      let combined = unbindWith g vs $! partialArrayOf p bds $! out
      in  (primal out, combined)
  where
    (vs, bds) = bind as
{-# INLINE gradWith' #-}
-- | Jacobian of a function whose outputs live in a 'Functor' @g@.
--
-- All outputs are recorded on one tape. For each output, a row shaped like
-- the input container is read off with 'partialArrayOf' and 'unbind'.
-- Unlike 'grad', this function adds no '$!'. Each row therefore remains a
-- thunk until it is demanded.
jacobian
  :: (Traversable f, Functor g)
  => (forall s. Reifies s Tape => f (ReverseDouble s) -> g (ReverseDouble s))
  -> f Double
  -> g (f Double)
jacobian f as = reifyTape (snd bds) $ \p ->
  fmap (\out -> unbind vs (partialArrayOf p bds out)) (f vs)
  where
    (vs, bds) = bind as
{-# INLINE jacobian #-}
-- | Like 'jacobian', but pairs every row with the 'primal' value of the
-- corresponding output.
--
-- Within each row, the output and its partial array are forced before
-- 'unbind' runs. The row itself is still only built when the pair's second
-- component is demanded.
jacobian'
  :: (Traversable f, Functor g)
  => (forall s. Reifies s Tape => f (ReverseDouble s) -> g (ReverseDouble s))
  -> f Double
  -> g (Double, f Double)
jacobian' f as = reifyTape (snd bds) $ \p ->
  fmap
    (\out -> (primal out, unbind vs $! partialArrayOf p bds $! out))
    (f vs)
  where
    (vs, bds) = bind as
{-# INLINE jacobian' #-}
-- | Like 'jacobian', but every partial in a row is combined with its input
-- through @g@ via 'unbindWith'.
--
-- As in 'jacobian', no explicit forcing is applied; rows are built lazily.
jacobianWith
  :: (Traversable f, Functor g)
  => (Double -> Double -> b)
  -> (forall s. Reifies s Tape => f (ReverseDouble s) -> g (ReverseDouble s))
  -> f Double
  -> g (f b)
jacobianWith g f as = reifyTape (snd bds) $ \p ->
  fmap (\out -> unbindWith g vs (partialArrayOf p bds out)) (f vs)
  where
    (vs, bds) = bind as
{-# INLINE jacobianWith #-}
-- | Combination of 'jacobian'' and 'jacobianWith'. Each output's 'primal'
-- value is paired with its row of partials, combined with the inputs
-- through @g@ via 'unbindWith'.
--
-- Within each row, the output and its partial array are forced before
-- 'unbindWith' runs.
jacobianWith'
  :: (Traversable f, Functor g)
  => (Double -> Double -> b)
  -> (forall s. Reifies s Tape => f (ReverseDouble s) -> g (ReverseDouble s))
  -> f Double
  -> g (Double, f b)
jacobianWith' g f as = reifyTape (snd bds) $ \p ->
  fmap
    (\out -> (primal out, unbindWith g vs $! partialArrayOf p bds $! out))
    (f vs)
  where
    (vs, bds) = bind as
{-# INLINE jacobianWith' #-}
-- | Derivative of a scalar function at a point.
--
-- The input is introduced as tape variable @0@ (@var a 0@) on a tape
-- created with @reifyTape 1@. The output is forced before 'derivativeOf'
-- reads the derivative.
diff
  :: (forall s. Reifies s Tape => ReverseDouble s -> ReverseDouble s)
  -> Double
  -> Double
diff f a = reifyTape 1 $ \p ->
  case f (var a 0) of
    !out -> derivativeOf p out
{-# INLINE diff #-}
-- | Like 'diff', but returns the pair produced by 'derivativeOf''. That pair
-- is presumably @(value, derivative)@; confirm against 'derivativeOf''.
--
-- The output is forced before 'derivativeOf'' is applied.
diff'
  :: (forall s. Reifies s Tape => ReverseDouble s -> ReverseDouble s)
  -> Double
  -> (Double, Double)
diff' f a = reifyTape 1 $ \p ->
  case f (var a 0) of
    !out -> derivativeOf' p out
{-# INLINE diff' #-}
-- | Derivative of every output of a function from one 'Double' to a
-- 'Functor' of outputs.
--
-- The single input is tape variable @0@ on a tape created with
-- @reifyTape 1@. 'derivativeOf' is mapped over the outputs without any
-- explicit forcing.
diffF
  :: Functor f
  => (forall s. Reifies s Tape => ReverseDouble s -> f (ReverseDouble s))
  -> Double
  -> f Double
diffF f a = reifyTape 1 $ \p ->
  fmap (derivativeOf p) $ f (var a 0)
{-# INLINE diffF #-}
-- | Like 'diffF', but maps 'derivativeOf'' over the outputs. Each output
-- therefore yields a pair, presumably @(value, derivative)@; confirm against
-- 'derivativeOf''.
diffF'
  :: Functor f
  => (forall s. Reifies s Tape => ReverseDouble s -> f (ReverseDouble s))
  -> Double
  -> f (Double, Double)
diffF' f a = reifyTape 1 $ \p ->
  fmap (derivativeOf' p) $ f (var a 0)
{-# INLINE diffF' #-}
-- | Hessian of a scalar function. It is computed as this module's 'jacobian'
-- of a gradient taken with the generic reverse mode ('M.grad' from
-- "Numeric.AD.Mode.Reverse").
--
-- The inner 'R.Reverse' layer runs over 'ReverseDouble' values supplied by
-- the outer 'jacobian'. Inputs are wrapped with 'On' before @f@ sees them,
-- and the result is unwrapped with 'off'.
hessian
  :: Traversable f
  => (forall s s'. (Reifies s R.Tape, Reifies s' Tape)
        => f (On (R.Reverse s (ReverseDouble s')))
        -> On (R.Reverse s (ReverseDouble s')))
  -> f Double
  -> f (f Double)
hessian f = jacobian (M.grad (\xs -> off (f (On <$> xs))))
{-# INLINE hessian #-}
-- | Hessian of every output of a function with a 'Functor' @g@ of outputs.
--
-- The inner, generic reverse-mode 'M.jacobian' produces a @g (f _)@ of
-- gradients. 'Compose' presents that nested result as a single functor, so
-- this module's 'jacobian' can differentiate it once more. 'getCompose' then
-- unpacks the final @g (f (f Double))@.
hessianF
  :: (Traversable f, Functor g)
  => (forall s s'. (Reifies s R.Tape, Reifies s' Tape)
        => f (On (R.Reverse s (ReverseDouble s')))
        -> g (On (R.Reverse s (ReverseDouble s'))))
  -> f Double
  -> g (f (f Double))
hessianF f = \xs ->
  getCompose $
    jacobian
      (\ys -> Compose (M.jacobian (\zs -> off <$> f (On <$> zs)) ys))
      xs
{-# INLINE hessianF #-}