{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Numeric.AD.Mode.Reverse
( Reverse, auto
, grad
, grad'
, gradWith
, gradWith'
, jacobian
, jacobian'
, jacobianWith
, jacobianWith'
, hessian
, hessianF
, diff
, diff'
, diffF
, diffF'
) where
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
import Data.Traversable (Traversable)
#endif
import Data.Functor.Compose
import Data.Reflection (Reifies)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Reverse
import Numeric.AD.Mode
-- | Compute the gradient of a function from a container of inputs to a
-- single result, using reverse-mode AD.
--
-- >>> grad (\[x,y,z] -> x*y+z) [1,2,3]
-- [2,1,1]
grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a
-- The two '$!'s fix the evaluation order: @f vs@ is forced to WHNF before
-- 'partialArrayOf' receives it, and the partials array is forced before
-- 'unbind' receives it. NOTE(review): presumably this guarantees the tape is
-- fully recorded before it is read back; confirm against
-- "Numeric.AD.Internal.Reverse" before reordering anything here.
grad f as = reifyTape (snd bds) $ \p -> unbind vs $! partialArrayOf p bds $! f vs where
  -- 'bind' yields the input variables (vs) and index bounds (bds); 'snd bds'
  -- presumably counts the variables (compare @reifyTape 1@ with @var a 0@ in
  -- 'diff'). vs is bound outside the rank-2 callback, so it must be
  -- generalised over the tape tag s. NOTE(review): this relies on
  -- NoMonoLocalBinds; re-check if MonoLocalBinds (e.g. via GADTs or
  -- TypeFamilies) is ever enabled for this module.
  (vs, bds) = bind as
{-# INLINE grad #-}
-- | Like 'grad', but also return the value of the function at the given
-- point, paired with the gradient.
--
-- >>> grad' (\[x,y,z] -> x*y+z) [1,2,3]
-- (5,[2,1,1])
grad' :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> (a, f a)
-- The variable pattern in the @case@ does not itself force @f vs@; the
-- gradient component forces it (and then the partials array) via '$!' when
-- that component is demanded. NOTE(review): the forcing order presumably
-- matters for tape recording; confirm against "Numeric.AD.Internal.Reverse".
grad' f as = reifyTape (snd bds) $ \p -> case f vs of
    r -> (primal r, unbind vs $! partialArrayOf p bds $! r)
  -- 'bind' yields the input variables (vs) and index bounds (bds). vs is
  -- bound outside the rank-2 callback and so must be generalised over the
  -- tape tag s. NOTE(review): relies on NoMonoLocalBinds.
  where (vs, bds) = bind as
{-# INLINE grad' #-}
-- | Like 'grad', but instead of returning the partial derivatives directly,
-- combine each one with its corresponding input using @g@.
-- NOTE(review): the argument order of @g@ is decided by 'unbindWith';
-- presumably @g input partial@. Confirm before documenting it as fact.
gradWith :: (Traversable f, Num a) => (a -> a -> b) -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f b
-- The two '$!'s force @f vs@ before 'partialArrayOf' and the partials array
-- before 'unbindWith'. NOTE(review): presumably required so the tape is fully
-- recorded before it is read; confirm against "Numeric.AD.Internal.Reverse".
gradWith g f as = reifyTape (snd bds) $ \p -> unbindWith g vs $! partialArrayOf p bds $! f vs
  -- vs is bound outside the rank-2 callback and so must be generalised over
  -- the tape tag s. NOTE(review): relies on NoMonoLocalBinds.
  where (vs,bds) = bind as
{-# INLINE gradWith #-}
-- | Like 'gradWith', but also return the value of the function at the given
-- point, paired with the combined results.
-- NOTE(review): the argument order of @g@ is decided by 'unbindWith';
-- presumably @g input partial@. Confirm.
gradWith' :: (Traversable f, Num a) => (a -> a -> b) -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> (a, f b)
-- The variable pattern in the @case@ does not force @f vs@; only the second
-- component forces it (and then the partials array) via '$!' when demanded.
gradWith' g f as = reifyTape (snd bds) $ \p -> case f vs of
    r -> (primal r, unbindWith g vs $! partialArrayOf p bds $! r)
  -- vs is bound outside the rank-2 callback and so must be generalised over
  -- the tape tag s. NOTE(review): relies on NoMonoLocalBinds.
  where (vs, bds) = bind as
{-# INLINE gradWith' #-}
-- | Compute the Jacobian of a function from a container of inputs to a
-- functor of outputs. Every output is replaced by its gradient with respect
-- to all of the inputs; all rows are read from the same tape.
--
-- >>> jacobian (\[x,y] -> [y,x,x*y]) [2,1]
-- [[0,1],[1,0],[1,2]]
jacobian :: (Traversable f, Functor g, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> g (Reverse s a)) -> f a -> g (f a)
jacobian f as = reifyTape (snd bounds) $ \p ->
  fmap (unbind xs . partialArrayOf p bounds) (f xs)
  where
    -- Kept outside the callback (and therefore generalised over the tape tag)
    -- so the same input variables can be used under the rank-2 'reifyTape'.
    (xs, bounds) = bind as
{-# INLINE jacobian #-}
-- | Like 'jacobian', but pair each output's value with its gradient.
--
-- >>> jacobian' (\[x,y] -> [y,x,x*y]) [2,1]
-- [(1,[0,1]),(2,[1,0]),(2,[1,2])]
jacobian' :: (Traversable f, Functor g, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> g (Reverse s a)) -> f a -> g (a, f a)
jacobian' f as = reifyTape (snd bounds) $ \p ->
  -- For each output y: its primal value, plus its partials (y is forced
  -- before 'partialArrayOf', and the array before 'unbind').
  fmap (\y -> (primal y, unbind xs $! partialArrayOf p bounds $! y)) (f xs)
  where
    (xs, bounds) = bind as
{-# INLINE jacobian' #-}
-- | Like 'jacobian', but for every output combine each partial derivative
-- with its corresponding input using @g@ (via 'unbindWith').
-- NOTE(review): the argument order of @g@ is decided by 'unbindWith'.
jacobianWith :: (Traversable f, Functor g, Num a) => (a -> a -> b) -> (forall s. Reifies s Tape => f (Reverse s a) -> g (Reverse s a)) -> f a -> g (f b)
jacobianWith g f as = reifyTape (snd bounds) $ \p ->
  fmap (unbindWith g xs . partialArrayOf p bounds) (f xs)
  where
    (xs, bounds) = bind as
{-# INLINE jacobianWith #-}
-- | Like 'jacobianWith', but pair each output's value with its combined
-- row.
-- NOTE(review): the argument order of @g@ is decided by 'unbindWith'.
jacobianWith' :: (Traversable f, Functor g, Num a) => (a -> a -> b) -> (forall s. Reifies s Tape => f (Reverse s a) -> g (Reverse s a)) -> f a -> g (a, f b)
jacobianWith' g f as = reifyTape (snd bounds) $ \p ->
  -- y is forced before 'partialArrayOf', and the array before 'unbindWith'.
  fmap (\y -> (primal y, unbindWith g xs $! partialArrayOf p bounds $! y)) (f xs)
  where
    (xs, bounds) = bind as
{-# INLINE jacobianWith' #-}
-- | Compute the derivative of a scalar-to-scalar function at a point, using a
-- tape with a single variable (index 0).
--
-- >>> diff (\x -> x*x*x) 2
-- 12
diff :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a) -> a -> a
diff f a = reifyTape 1 $ \p ->
  -- Force the result to WHNF before reading its derivative (same as '$!').
  case f (var a 0) of
    r -> r `seq` derivativeOf p r
{-# INLINE diff #-}
-- | Like 'diff', but return the function's value together with its
-- derivative.
--
-- >>> diff' (\x -> x*x*x) 2
-- (8,12)
diff' :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a) -> a -> (a, a)
diff' f a = reifyTape 1 $ \p ->
  -- Force the result to WHNF before reading it back (same as '$!').
  case f (var a 0) of
    r -> r `seq` derivativeOf' p r
{-# INLINE diff' #-}
-- | Differentiate a scalar-to-functor function: every output is replaced by
-- its derivative with respect to the single input.
--
-- >>> diffF (\a -> [a*a, 3*a]) 2
-- [4,3]
diffF :: (Functor f, Num a) => (forall s. Reifies s Tape => Reverse s a -> f (Reverse s a)) -> a -> f a
diffF f a = reifyTape 1 $ \p -> fmap (derivativeOf p) (f (var a 0))
{-# INLINE diffF #-}
-- | Like 'diffF', but pair every output's value with its derivative.
--
-- >>> diffF' (\a -> [a*a, 3*a]) 2
-- [(4,4),(6,3)]
diffF' :: (Functor f, Num a) => (forall s. Reifies s Tape => Reverse s a -> f (Reverse s a)) -> a -> f (a, a)
diffF' f a = reifyTape 1 $ \p -> fmap (derivativeOf' p) (f (var a 0))
{-# INLINE diffF' #-}
-- | Compute the Hessian of a function from a container of inputs to a single
-- result. It is the 'jacobian' (tape tag @s'@) of the 'grad' (tape tag @s@);
-- 'On' wraps the nested @Reverse s (Reverse s' a)@ mode and 'off' unwraps it.
--
-- >>> hessian (\[x,y] -> x*y) [1,2]
-- [[0,1],[1,0]]
hessian :: (Traversable f, Num a) => (forall s s'. (Reifies s Tape, Reifies s' Tape) => f (On (Reverse s (Reverse s' a))) -> (On (Reverse s (Reverse s' a)))) -> f a -> f (f a)
hessian f = jacobian (\ys -> grad (\xs -> off (f (On <$> xs))) ys)
{-# INLINE hessian #-}
-- | Compute the Hessian of every output of a function from a container of
-- inputs to a functor of outputs. The inner Jacobian is wrapped in 'Compose'
-- so that the outer 'jacobian' can differentiate it again, and the result is
-- then unwrapped with 'getCompose'.
--
-- >>> hessianF (\[x,y] -> [x*y, x*x]) [1,2]
-- [[[0,1],[1,0]],[[2,0],[0,0]]]
hessianF :: (Traversable f, Functor g, Num a) => (forall s s'. (Reifies s Tape, Reifies s' Tape) => f (On (Reverse s (Reverse s' a))) -> g (On (Reverse s (Reverse s' a)))) -> f a -> g (f (f a))
hessianF f =
  getCompose . jacobian (\ys -> Compose (jacobian (\xs -> off <$> f (On <$> xs)) ys))
{-# INLINE hessianF #-}