inf-backprop-0.1.0.0: Automatic differentiation and backpropagation.
Copyright    (C) 2023 Alexey Tochin
License      BSD3 (see the file LICENSE)
Maintainer   Alexey Tochin <Alexey.Tochin@gmail.com>
Safe Haskell Safe-Inferred
Language     Haskell2010
Extensions
  • MonoLocalBinds
  • ScopedTypeVariables
  • TypeFamilies
  • GADTs
  • GADTSyntax
  • ConstraintKinds
  • InstanceSigs
  • DeriveFunctor
  • TypeSynonymInstances
  • FlexibleContexts
  • FlexibleInstances
  • ConstrainedClassMethods
  • MultiParamTypeClasses
  • KindSignatures
  • TupleSections
  • RankNTypes
  • ExplicitNamespaces
  • ExplicitForAll

InfBackprop.Tutorial

Description

Tutorial for the inf-backprop package.

Synopsis

    Quick start

    >>> :set -XNoImplicitPrelude
    >>> import Prelude (Float, fmap)
    >>> import InfBackprop (BackpropFunc, call, derivative, derivativeN, pow)
    

    We can define a differentiable function

    \[ f(x) := x^2 \]

    as follows

    >>> smoothF = pow 2 :: BackpropFunc Float Float
    

    where pow is the differentiable power function and BackpropFunc :: * -> * -> * is the type of infinitely differentiable (smooth) functions. We can obtain the function values with the call method like

    >>> f = call smoothF :: Float -> Float
    >>> fmap f [-3, -2, -1, 0, 1, 2, 3]
    [9.0,4.0,1.0,0.0,1.0,4.0,9.0]
    

    as well as the first derivative using derivative, which is

    \[ f'(x) = 2 \cdot x \]

    >>> df = derivative smoothF :: Float -> Float
    >>> fmap df [-3, -2, -1, 0, 1, 2, 3]
    [-6.0,-4.0,-2.0,0.0,2.0,4.0,6.0]
    

    or the second derivative

    \[ f''(x) = 2 \]

    >>> d2f = derivativeN 2 smoothF :: Float -> Float
    >>> fmap d2f [-3, -2, -1, 0, 1, 2, 3]
    [2.0,2.0,2.0,2.0,2.0,2.0,2.0]
    

    and so on.

    A composition of two functions like

    \[ g(x) := \log x^3 \]

    must be defined with the categorical composition (>>>) (or (<<<))

    >>> import InfBackprop (log)
    >>> import Control.Category ((>>>), (<<<))
    >>> smoothG = pow 3 >>> log
    

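    As before, the value and the first derivative of this composition can be obtained with call and derivative; a minimal sketch, assuming Float arguments:

      dg = derivative smoothG :: Float -> Float
      -- g'(x) = 3 / x for positive x, so dg 2 is expected to be 1.5
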
    For more complicated expressions, for example,

    \[ h(x) := x^2 + x^3 \]

    we use the arrow-style combinators (***), first, and second as follows

    >>> import InfBackprop ((+), dup)
    >>> import Control.CatBifunctor ((***))
    
    >>> smoothH = dup >>> (pow 2 *** pow 3) >>> (+) :: BackpropFunc Float Float
    

    where

      dup :: BackpropFunc a (a, a)
    

    is a differentiable function that simply duplicates the single implicit argument x into the tuple (x, x). This is needed to pass the implicit x to the two independent functions pow 2 and pow 3. The last

      (+) :: BackpropFunc (a, a) a
    

    operation transforms the pair of implicit arguments into their sum.
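
    For example, the value and the first derivative of smoothH can be taken exactly as in the quick start above; a minimal sketch, assuming Float arguments:

      h  = call smoothH :: Float -> Float
      dh = derivative smoothH :: Float -> Float
      -- h x   = x^2 + x^3,     e.g. h 2  is expected to be 12.0
      -- h'(x) = 2*x + 3*x^2,   e.g. dh 2 is expected to be 16.0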

    Derivatives for symbolic expressions

    >>> import Prelude (($))
    >>> import Control.Category ((<<<))
    >>> import InfBackprop (BackpropFunc, call, derivative, derivativeN, sin, pow, (**), setSecond, const)
    

    We use simple-expr package here.

    >>> import Debug.SimpleExpr.Expr (SimpleExpr, variable, simplify)
    

    For example a symbolic function

    \[ f(x) := \sin x^2 \]

    can be defined as follows

    >>> x = variable "x"
    >>> f = sin <<< pow 2 :: BackpropFunc SimpleExpr SimpleExpr
    

    see the quick start section above for details. We can call the symbolic function like

    >>> call f x
    sin(x·x)
    

    and find the symbolic derivative

    \[ \frac{d}{d x} f(x) = \frac{d}{d x} \sin x^2 = 2\, x \cos x^2 \]

    as follows

    >>> simplify $ derivative f x
    cos(x·x)·(2·x)
    

    as well as the second and higher derivatives

    >>> simplify $ derivativeN 2 f x
    (((2·x)·-(sin(x·x)))·(2·x))+(2·cos(x·x))
    

    Symbolic expressions visualization

    The simple-expr package is equipped with a visualisation tool that can be used to illustrate how the differentiation works.

    >>> import Control.Category ((<<<))
    >>> import InfBackprop (call, backpropExpr)
    >>> import Debug.SimpleExpr.Expr (SimpleExpr, variable, simplify)
    >>> import Debug.SimpleExpr.GraphUtils (exprToGraph)
    >>> import Data.Graph.VisualizeAlternative (plotDGraph)
    

    As a warm-up, a trivial composition of two functions

    \[ g(f(x)) \]

    is defined as

    >>> x = variable "x"
    >>> call (backpropExpr "g" <<< backpropExpr "f") x
    g(f(x))
    

    It can be plotted by

     plotExpr $ call (backpropExpr "g" <<< backpropExpr "f") x

    The graph for the first derivative can be depicted by

     plotExpr $ simplify $ derivative (backpropExpr "g" <<< backpropExpr "f") x

    where simplify :: SimpleExpr -> SimpleExpr simply removes trivial terms such as multiplication by 1 and addition of 0.

    The second derivative is just as straightforward

     plotExpr $ simplify $ derivativeN 2 (backpropExpr "g" <<< backpropExpr "f") x

    How it works

    The idea can be understood from the example of a composition of three functions

    \[ g(f(h(x))) \]

    with a focus on the function f. Its first derivative with respect to x is

    \[ h'(x) \cdot f'(h(x)) \cdot g'(f(h(x))). \]

    According to the backpropagation strategy, the order of the calculation should be as follows.

    1. Find h(x).
    2. Find f(h(x)).
    3. Find g(f(h(x))).
    4. Find the top derivative g'(f(h(x))).
    5. Find the next to the top derivative f'(h(x)).
    6. Multiply g'(f(h(x))) by f'(h(x)).
    7. Find the next derivative h'(x).
    8. Multiply the output of step 6 by h'(x).

    The generalization for longer composition is straightforward.

    All calculations related to the function f can be divided into two parts. We have to find f of h(x) first (forward step), and then find the derivative f' at the same argument h(x) and multiply it by the derivative g'(f(h(x))) obtained during the analogous calculations for g (backward step). Notice that the value of h(x) is reused on the backward step. To implement this, we define the type Backprop (see the corresponding documentation for details).
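
    The scheme above can be sketched in plain Haskell; this is an illustration only, not the package's Backprop machinery. Each function is represented here by a hypothetical pair of its value and its derivative, the intermediate arguments are computed on the forward pass and reused on the backward pass:

      -- Hypothetical helper for illustration: a function together with its derivative.
      data Smooth = Smooth { value :: Double -> Double, deriv :: Double -> Double }

      -- Derivative of g . f . h at x, following the ordering 1-8 above.
      derivativeOfComposition :: Smooth -> Smooth -> Smooth -> Double -> Double
      derivativeOfComposition h f g x =
        let hx  = value h x    -- 1. forward: h(x), kept for the backward step
            fhx = value f hx   -- 2. forward: f(h(x)), kept as well
            -- 3. g(f(h(x))) would be evaluated here if its value were needed
            dg  = deriv g fhx  -- 4. top derivative g'(f(h(x)))
            df  = deriv f hx   -- 5. next derivative f'(h(x)), reusing h(x)
            dh  = deriv h x    -- 7. next derivative h'(x)
        in  (dg * df) * dh     -- 6. and 8. multiply the derivatives top-down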

    Declaring custom derivative

    >>> import Prelude (Float)
    >>> import qualified Prelude
    >>> import Control.Category ((>>>))
    >>> import InfBackprop ((*), negate, dup, BackpropFunc, Backprop(MkBackprop), first, second)
    

    As an illustrative example, a differentiable version of the numerical cos function can be defined as follows (see the documentation for Backprop for details)

      cos :: BackpropFunc Float Float
      cos = MkBackprop call' forward' backward' where
        call' :: Float -> Float
        call' = Prelude.cos
    
        forward' :: BackpropFunc Float (Float, Float)
        forward' = dup >>> first cos
    
        backward' :: BackpropFunc (Float, Float) Float
        backward' = second (sin >>> negate) >>> (*)
    
      sin :: BackpropFunc Float Float
      sin = ...
    

    Here we use the Prelude implementation of the ordinary cos function in call. The forward function is a differentiable function (which is needed for further derivatives) with two output values. Roughly speaking, forward is x -> (cos x, x). The first term of the tuple is just cos, and the second term x is the value to be reused on the backward step. The backward is (dy, x) -> dy * (-sin x), where dy is the derivative found on the previous backward step and the second value is the x stored by forward. We simply multiply, with (*), the derivative dy by the derivative of cos, which is -sin.

    The stored value is not necessarily just x. It could be anything useful for the backward step; see, for example, the implementation of exp and the corresponding example below.
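
    For instance, a differentiable exp along the lines of the cos example above might store the value exp x itself, since exp is its own derivative. This is a sketch, not necessarily the package's actual definition:

      exp :: BackpropFunc Float Float
      exp = MkBackprop call' forward' backward' where
        call' :: Float -> Float
        call' = Prelude.exp

        forward' :: BackpropFunc Float (Float, Float)
        forward' = exp >>> dup  -- store exp x itself as the cached value

        backward' :: BackpropFunc (Float, Float) Float
        backward' = (*)         -- multiply dy by the stored exp x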

    Differentiation of monadic functions

    Differentiable versions of monadic functions a -> m b can also be backpropagated. For example, consider a real-valued power function defined only for positive real numbers. For a non-positive number it returns Nothing, which is a signal to stop computing the derivative and to return Nothing, in the spirit of the Maybe monad. For this purpose, we can use the fact that the Backprop type is defined for any category, not only for functions (->). In particular, we can use Backprop (Kleisli Maybe) instead of the Backprop (->) from the previous sections.

    >>> import Prelude (Maybe, Maybe(Just, Nothing), ($), Ord, (>), Float)
    >>> import InfBackprop (Backprop(MkBackprop), derivative, dup, (*), linear, pureBackprop, first, second)
    >>> import Control.Arrow (Kleisli(Kleisli), runKleisli, (>>>))
    >>> import qualified NumHask as NH
    

    The function

     pureBackprop :: Monad m => Backprop (->) a b -> Backprop (Kleisli m) a b
    

    trivially lifts an ordinary backpropagation function to the monadic function type.

    Define the power function as follows

    >>> :{
     powR :: forall a. (Ord a, NH.ExpField a) =>
       a -> Backprop (Kleisli Maybe) a a
     powR p = MkBackprop call' forward' backward'
       where
         call' :: Kleisli Maybe a a
         call' = Kleisli $ \x -> if x > NH.zero
           then Just $ x NH.** p
           else Nothing
         --
         forward' :: Backprop (Kleisli Maybe) a (a, a)
         forward' = pureBackprop dup >>> first (powR p)
         --
         backward' :: Backprop (Kleisli Maybe) (a, a) a
         backward' = second der >>> pureBackprop (*) where
           der = powR (p NH.- NH.one) >>> pureBackprop (linear p)
    :}
    

    and calculate

    \[ \frac{d}{dx} x^{\frac12} = \frac{1}{2 \sqrt{x}} \]

    for x=4 and x=-4 like

    >>> runKleisli (derivative (powR 0.5)) (4 :: Float)
    Just 0.25
    >>> runKleisli (derivative (powR 0.5)) (-4 :: Float)
    Nothing
    

    Differentiation with logging

    Our objective now is to add logging to the derivative calculation. The type Backprop cat a b is parametrized by a category cat, an input a, and an output b. If cat is (->), the type reduces to the BackpropFunc we worked with above. To add logging to the calculation we replace (->) by MonadLogger m => Kleisli m. We will need the imports below

    >>> import Prelude (Integer, Float, ($), (+), (*))
    >>> import Control.Monad.Logger (runStdoutLoggingT, MonadLogger)
    >>> import Control.Arrow ((>>>), (***), runKleisli, Kleisli)
    >>> import InfBackprop (derivative, loggingBackpropExpr)
    >>> import Debug.SimpleExpr.Expr (variable)
    >>> import Debug.LoggingBackprop (initUnaryFunc, initBinaryFunc, pureKleisli, exp, sin)
    

    where the module Debug.LoggingBackprop contains some useful functionality, for example, lifts for unary functions

     initUnaryFunc :: (Show a, Show b, MonadLogger m) => String -> (a -> b) -> Kleisli m a b
    

    and binary functions

     initBinaryFunc :: (Show a, Show b, Show c, MonadLogger m) => String -> (a -> b -> c) -> Kleisli m (a, b) c
    

    These two functions lift given functions to Kleisli category morphisms, which allows logging during their execution.

    Let us first explain how it works with the following example.

    \[ f(x) = y \cdot 3 + y \cdot 4, \quad y = x + 2. \]

    This function can be defined as follows

    >>> :{
     fLogging :: MonadLogger m => Kleisli m Integer Integer
     fLogging =
       initUnaryFunc "+2" (+2) >>>
       (pureKleisli (\x -> (x, x))) >>>
       (initUnaryFunc "*3" (*3) *** initUnaryFunc "*4" (*4)) >>>
       initBinaryFunc "sum" (+)
    :}
    

    We run the calculation with x = 5 as follows

    >>> runStdoutLoggingT $ runKleisli fLogging 5
    [Info] Calculating +2 of 5 => 7
    [Info] Calculating *3 of 7 => 21
    [Info] Calculating *4 of 7 => 28
    [Info] Calculating sum of 21 and 28 => 49
    49
    

    We are now ready to consider an example with derivatives. Let us calculate the following simple expression

    \[ \frac{d}{dx} \mathrm{f} (e^x) = e^x f'(e^x) \]

    We define the symbolic function f by

     loggingBackpropExpr :: String -> BackpropFunc SimpleExpr SimpleExpr
    

    and the entire derivative is

    >>> runStdoutLoggingT $ runKleisli (derivative (exp >>> loggingBackpropExpr "f")) (variable "x")
    [Info] Calculating exp of x => exp(x)
    [Info] Calculating f of exp(x) => f(exp(x))
    [Info] Calculating f' of exp(x) => f'(exp(x))
    [Info] Calculating multiplication of 1 and f'(exp(x)) => 1·f'(exp(x))
    [Info] Calculating multiplication of 1·f'(exp(x)) and exp(x) => (1·f'(exp(x)))·exp(x)
    (1·f'(exp(x)))·exp(x)
    

    For illustration, we can set 'f = sin' and 'x = 2'

    \[ \left. \frac{d}{dx} \sin (e^x) \right|_{x=2} = e^2 \cos (e^2) \]

    >>> runStdoutLoggingT $ runKleisli (derivative (exp >>> sin)) (2 :: Float)
    [Info] Calculating exp of 2.0 => 7.389056
    [Info] Calculating sin of 7.389056 => 0.893855
    [Info] Calculating cos of 7.389056 => 0.44835615
    [Info] Calculating multiplication of 1.0 and 0.44835615 => 0.44835615
    [Info] Calculating multiplication of 0.44835615 and 7.389056 => 3.312929
    3.312929
    

    The first thing to notice in these logs is that the last forward step sin(exp x) is still computed, unlike in the examples from the previous sections. This is due to the monadic nature of the calculation chain; this extra computation disappears as soon as we return from Kleisli m to (->).

    The second thing to notice is that the exponent exp x is calculated only once, thanks to the cached value passed from the forward to the backward method.