{-# OPTIONS_HADDOCK show-extensions #-}

{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}

{-|
Module      : Data.Utils.Analytic
Description : "analytic" values
Copyright   : (c) Lars Brünjes, 2016
License     : MIT
Maintainer  : brunjlar@gmail.com
Stability   : experimental
Portability : portable

This module defines the numeric class 'Analytic', the type @'Diff' f g@ of "differentiable"
functions, and an adapted version of 'Numeric.AD.gradWith''.
-}

module Data.Utils.Analytic
    ( Analytic(..)
    , Diff(..)
    , Diff'
    , diff
    , gradWith'
    ) where

import           Control.Category
import           Data.MyPrelude
import           Data.Reflection             (Reifies)
import qualified Numeric.AD                  as AD                  
import           Numeric.AD.Internal.Reverse (Reverse, Tape)
import Prelude                               hiding (id, (.))

-- | Class 'Analytic' is a helper class for defining differentiable functions.
--
--   An instance is a number type that supports the 'Floating' and 'Ord'
--   operations and can additionally embed 'Double' values via 'fromDouble'.
--   Code written against an arbitrary @'Analytic' a@ can therefore be run
--   both on plain 'Double's and on the reverse-mode AD number type.
class (Ord a, Floating a) => Analytic a where
    -- | Converts a 'Double' into the number type @a@.
    fromDouble :: Double -> a

-- | Plain 'Double's are analytic; 'fromDouble' returns its argument unchanged.
instance Analytic Double where
    fromDouble x = x

-- | Numbers of the reverse-mode AD type 'Reverse' (over 'Double') are
--   analytic; 'fromDouble' lifts a 'Double' into that type using
--   'Numeric.AD.auto'.
instance Reifies s Tape => Analytic (Reverse s Double) where
    fromDouble d = AD.auto d

-- | Type @'Diff' f g@ can be thought of as the type of "differentiable"
--   functions @f Double -> g Double@.
--
--   The wrapped function is polymorphic in the number type, so one
--   @'Diff' f g@ value can be run at any 'Analytic' type, e.g. plain
--   'Double's or AD numbers.
newtype Diff f g = Diff
    { runDiff :: forall a. Analytic a => f a -> g a
      -- ^ Runs the function at a chosen 'Analytic' number type.
    }

-- | Differentiable functions form a category: 'id' returns its input
--   unchanged, and @outer . inner@ first applies @inner@, then @outer@.
instance Category Diff where

    id = Diff (\x -> x)

    Diff outer . Diff inner = Diff (\x -> outer (inner x))

-- | Type @'Diff''@ can be thought of as the type of differentiable functions
--   @Double -> Double@.
--
--   Such a function is written once for every 'Analytic' number type
--   @b@, so it can be evaluated on plain 'Double's as well as on AD numbers.
type Diff' = forall b. Analytic b => b -> b

-- | Lifts a differentiable scalar function to a @'Diff' f f@ by applying it
--   pointwise, i.e. to every element of the 'Functor' @f@ via 'fmap'.
--
diff :: Functor f => Diff' -> Diff f f
diff scalar = Diff (\xs -> fmap scalar xs)

-- | Computes the value and the gradient of a differentiable function at the
--   given point (delegating to 'Numeric.AD.gradWith'') and combines each
--   argument component with the matching gradient component.
--
--   The combining function receives the argument first and the gradient
--   second.
--
-- >>> gradWith' (\_ d -> d) (Diff $ \[x, y] -> Identity $ x * x + 3 * y + 7) [2, 1]
-- (14.0,[4.0,3.0])
--
gradWith'
    :: Traversable t
    => (Double -> Double -> a) -- ^ how to combine argument and gradient
    -> Diff t Identity         -- ^ differentiable function
    -> t Double                -- ^ function argument
    -> (Double, t a)           -- ^ function value and combination of argument and gradient
gradWith' combine fun =
    AD.gradWith' combine (\xs -> runIdentity (runDiff fun xs))