Copyright | (c) Marco Zocca 2021 |
---|---|
License | BSD |
Maintainer | github.com/ocramz |
Stability | experimental |
Portability | POSIX |
Safe Haskell | Safe-Inferred |
Language | Haskell2010 |
Quickstart
Most users will only need to import rad1, rad2 or grad and leverage the Num, Fractional and Floating instances of the AD type.
Similarly to ad, a user supplies a polymorphic function to be differentiated, e.g.
f :: Num a => a -> a
f x = x + (x * x)
and the library takes care of the rest:
>>> rad1 f 1.2
(2.6399999999999997,3.4000000000000004)
grad computes the gradient of a scalar function of a vector argument. For full generality, the argument can be any Traversable container (a list, an array, a dictionary, etc.).
sqNorm :: Num a => [a] -> a
sqNorm xs = sum $ zipWith (*) xs xs
p :: [Double]
p = [4.1, 2]
>>> grad sqNorm p
(20.81,[8.2,4.0])
It's important to emphasize that the library cannot differentiate functions of concrete types, e.g. Double -> Double, because rad1, rad2 and grad need to instantiate the function at the AD type. On the other hand, it's easy to experiment with other numerical interfaces that support one, zero and plus.
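The restriction to polymorphic functions can be illustrated with a minimal sketch. It assumes the combinators are exported from a module named Numeric.AD.DelCont (this page only names Numeric.AD.DelCont.Internal); cubic, cubicD and cubicWithSlope are hypothetical names introduced for the example.

```haskell
import Numeric.AD.DelCont (rad1) -- module name is an assumption

-- Polymorphic in the numeric type: rad1 can instantiate 'a' with the
-- library's AD type, so this function can be differentiated.
cubic :: Num a => a -> a
cubic x = x * x * x + 2 * x

-- Monomorphic: the argument type is fixed to Double, so the expression
-- 'rad1 cubicD 2' would be rejected by the type checker.
cubicD :: Double -> Double
cubicD = cubic

-- Value and derivative of 'cubic' at a point.
cubicWithSlope :: Double -> (Double, Double)
cubicWithSlope = rad1 cubic
```

At x = 2 the function evaluates to 12 and its derivative 3x² + 2 to 14, so cubicWithSlope 2 should return that pair.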
Advanced usage
The library is small and easily extensible.
For example, a user might want to supply their own numerical typeclass other than Num, and build up a library of AD combinators based on that, specializing op1 and op2 with custom implementations of zero, one and plus. This insight first appeared in the user interface of backprop, as the Backprop typeclass.
Exposing unconstrained AD combinators lets users specialize this library to e.g. exotic number-like types or discrete data structures such as dictionaries, automata etc.
Implementation details and design choices
This is the first (known) Haskell implementation of the ideas presented in Wang et al. Here the role of variable mutation and delimited continuations is made explicit by the use of ST and ContT, as compared to the reference Scala implementation.
ad-delcont relies on non-standard interpretation of the user-provided function: in order to compute the adjoint values (the sensitivities) of the function parameters, the function is first evaluated ("forwards") while keeping track of continuation points, and all the intermediate adjoints are accumulated upon returning from the respective continuations ("backwards") via safe mutation in the ST monad.
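The forwards/backwards scheme described above can be sketched in a few lines using only base and transformers. This is not the library's actual code: it is a simplified illustration restricted to Double, and the names Var, newVar, lift2, add, mul and diff1 are hypothetical.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad.ST (ST, runST)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Cont (ContT, evalContT, shiftT)
import Data.STRef (STRef, modifySTRef', newSTRef, readSTRef, writeSTRef)

-- | A variable: its value plus a mutable cell accumulating its adjoint.
data Var s = Var { value :: Double, adjointRef :: STRef s Double }

newVar :: Double -> ST s (Var s)
newVar x = Var x <$> newSTRef 0

-- | Lift a binary operation, given a function returning the result and
-- the partial derivatives with respect to both operands.
lift2 :: (Double -> Double -> (Double, Double, Double))
      -> Var s -> Var s -> ContT r (ST s) (Var s)
lift2 f a b = shiftT $ \k -> lift $ do
  let (z, dza, dzb) = f (value a) (value b)
  c <- newVar z                    -- "forwards": compute the result
  r <- k c                         -- run the rest of the program
  dc <- readSTRef (adjointRef c)   -- "backwards": adjoint of the result
  modifySTRef' (adjointRef a) (+ dc * dza)
  modifySTRef' (adjointRef b) (+ dc * dzb)
  pure r

add, mul :: Var s -> Var s -> ContT r (ST s) (Var s)
add = lift2 (\x y -> (x + y, 1, 1))
mul = lift2 (\x y -> (x * y, y, x))

-- | Evaluate a unary function and return (result, derivative).
diff1 :: (forall s. Var s -> ContT Double (ST s) (Var s))
      -> Double -> (Double, Double)
diff1 f x0 = runST $ do
  x <- newVar x0
  y <- evalContT $ do
    out <- f x
    lift (writeSTRef (adjointRef out) 1) -- seed the output adjoint
    pure (value out)
  dx <- readSTRef (adjointRef x)
  pure (y, dx)

-- | x + x * x at x = 3: value 12, derivative 1 + 2 * 3 = 7.
example :: (Double, Double)
example = diff1 (\x -> mul x x >>= add x) 3
```

Each lifted operation captures the rest of the computation with shiftT, so the adjoint updates run in reverse order of evaluation once the continuation returns, without building an explicit tape.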
As a result of this design, the main AD type cannot be given Eq and Ord instances (since it's unclear how equality and ordering predicates would apply to continuations and state threads).
The user interface is inspired by that of ad and backprop; however, the internals are completely different in that this library doesn't reify the function to be differentiated into a "tape" data structure.
Another point in common with backprop is that users can differentiate heterogeneous functions: the input and output types can be different. This makes it possible to differentiate functions of statically-typed vectors and matrices.
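A small sketch of a heterogeneous function, where the input is a Float and the output a Double. It assumes the public module is named Numeric.AD.DelCont; widen and squareWide are hypothetical names, and the choice of realToFrac as both the conversion and its pullback is an illustration rather than something the library prescribes.

```haskell
import Numeric.AD.DelCont (AD', op1Num, rad1) -- module name is an assumption

-- | Widen a Float variable to Double. The pullback narrows the incoming
-- Double adjoint back to Float.
widen :: AD' s Float -> AD' s Double
widen = op1Num (\x -> (realToFrac x, realToFrac))

-- | Square a Float input in Double precision; the result is a Double,
-- while the adjoint has the type of the input.
squareWide :: Float -> (Double, Float)
squareWide = rad1 (\x -> let y = widen x in y * y)
```

For an input of 3 this should give a result of 9 and an adjoint of 6, with the two components at different types.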
References
- backprop - https://hackage.haskell.org/package/backprop
- ad - https://hackage.haskell.org/package/ad
- F. Wang et al, Backpropagation with Continuation Callbacks: Foundations for Efficient and Expressive Differentiable Programming, NeurIPS 2018 - https://papers.nips.cc/paper/2018/file/34e157766f31db3d2099831d348a7933-Paper.pdf
- F. Wang et al, Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator, ICFP 2019 - https://doi.org/10.1145/3341700 - https://www.cs.purdue.edu/homes/rompf/papers/wang-icfp19.pdf
- M. Innes, Don't Unroll Adjoint: Differentiating SSA-Form Programs - https://arxiv.org/abs/1810.07951
Synopsis
- rad1 :: (Num a, Num b) => (forall s. AD' s a -> AD' s b) -> a -> (b, a)
- rad2 :: (Num a, Num b, Num c) => (forall s. AD' s a -> AD' s b -> AD' s c) -> a -> b -> (c, (a, b))
- grad :: (Traversable t, Num a, Num b) => (forall s. t (AD' s a) -> AD' s b) -> t a -> (b, t a)
- auto :: a -> AD s a da
- rad1g :: da -> db -> (forall s. AD s a da -> AD s b db) -> a -> (b, da)
- rad2g :: da -> db -> dc -> (forall s. AD s a da -> AD s b db -> AD s c dc) -> a -> b -> (c, (da, db))
- radNg :: Traversable t => da -> db -> (forall s. t (AD s a da) -> AD s b db) -> t a -> (b, t da)
- op1 :: db -> (da -> da -> da) -> (a -> (b, db -> da)) -> AD s a da -> AD s b db
- op2 :: dc -> (da -> da -> da) -> (db -> db -> db) -> (a -> b -> (c, dc -> da, dc -> db)) -> AD s a da -> AD s b db -> AD s c dc
- op1Num :: (Num da, Num db) => (a -> (b, db -> da)) -> AD s a da -> AD s b db
- op2Num :: (Num da, Num db, Num dc) => (a -> b -> (c, dc -> da, dc -> db)) -> AD s a da -> AD s b db -> AD s c dc
- data AD0 s a
- type AD s a da = AD0 s (DVar s a da)
- type AD' s a = AD s a a
Quickstart
rad1 :: (Num a, Num b) | |
=> (forall s. AD' s a -> AD' s b) | function to be differentiated |
-> a | function argument |
-> (b, a) | (result, adjoint) |
Evaluate (forward mode) and differentiate (reverse mode) a unary function
>>> rad1 (\x -> x * x) 1
(1,2)
rad2 :: (Num a, Num b, Num c) | |
=> (forall s. AD' s a -> AD' s b -> AD' s c) | function to be differentiated |
-> a | |
-> b | |
-> (c, (a, b)) | (result, adjoints) |
Evaluate (forward mode) and differentiate (reverse mode) a binary function
>>> rad2 (\x y -> x + y + y) 1 1
(3,(1,2))
>>> rad2 (\x y -> (x + y) * x) 3 2
(15,(8,3))
grad :: (Traversable t, Num a, Num b) | |
=> (forall s. t (AD' s a) -> AD' s b) | |
-> t a | argument vector |
-> (b, t a) | (result, gradient vector) |
Evaluate (forward mode) and differentiate (reverse mode) a function of a Traversable.
In linear algebra terms, this computes the gradient of a scalar function of a vector argument.
sqNorm :: Num a => [a] -> a
sqNorm xs = sum $ zipWith (*) xs xs
p :: [Double]
p = [4.1, 2]
>>> grad sqNorm p
(20.81,[8.2,4.0])
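Since grad only requires a Traversable container, the same pattern should carry over to a dictionary of named parameters. The sketch below assumes the public module is named Numeric.AD.DelCont and uses Data.Map from the containers package; sumSquares, params and gradAtParams are hypothetical names.

```haskell
import qualified Data.Map.Strict as M
import Numeric.AD.DelCont (grad) -- module name is an assumption

-- | Sum of squares over any container of numbers.
sumSquares :: (Foldable t, Functor t, Num a) => t a -> a
sumSquares xs = sum (fmap (\x -> x * x) xs)

-- | Named parameters stored in a dictionary.
params :: M.Map String Double
params = M.fromList [("bias", 1), ("weight", 3)]

-- | The result, paired with one partial derivative per key.
gradAtParams :: (Double, M.Map String Double)
gradAtParams = grad sumSquares params
```

The gradient comes back in the same container shape as the argument, so each partial derivative stays attached to its key: here the result is 10, with 2 for "bias" and 6 for "weight".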
auto :: a -> AD s a da
Lift a constant value into AD
As one expects from a constant, its value will be used for computing the result, but it will be discarded when computing the sensitivities.
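A typical use of auto is to bring a value from the enclosing scope into a function being differentiated. The sketch below assumes the public module is named Numeric.AD.DelCont; scaledSquare is a hypothetical name.

```haskell
import Numeric.AD.DelCont (auto, rad1) -- module name is an assumption

-- | Differentiate k * x * x with respect to x only. The factor k comes
-- from the enclosing scope, so it is lifted with 'auto' and treated as
-- a constant: no sensitivity is reported for it.
scaledSquare :: Double -> Double -> (Double, Double)
scaledSquare k = rad1 (\x -> auto k * x * x)
```

With k = 5 and x = 3 the value is 45 and the derivative 2kx is 30.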
Advanced usage
rad1g :: da | zero |
-> db | one |
-> (forall s. AD s a da -> AD s b db) | |
-> a | function argument |
-> (b, da) | (result, adjoint) |
Evaluate (forward mode) and differentiate (reverse mode) a unary function, without committing to a specific numeric typeclass
rad2g :: da | zero |
-> db | zero |
-> dc | one |
-> (forall s. AD s a da -> AD s b db -> AD s c dc) | |
-> a | |
-> b | |
-> (c, (da, db)) | (result, adjoints) |
Evaluate (forward mode) and differentiate (reverse mode) a binary function, without committing to a specific numeric typeclass
radNg :: Traversable t | |
=> da | zero |
-> db | one |
-> (forall s. t (AD s a da) -> AD s b db) | |
-> t a | argument vector |
-> (b, t da) | (result, gradient vector) |
Evaluate (forward mode) and differentiate (reverse mode) a function of a Traversable.
In linear algebra terms, this computes the gradient of a scalar function of a vector argument.
Lift functions into AD
op1 :: db | zero |
-> (da -> da -> da) | plus |
-> (a -> (b, db -> da)) | returns : (function result, pullback) |
-> AD s a da | |
-> AD s b db |
Lift a unary function
The first two arguments constrain the types of the adjoint values of the output and input variable respectively, see op1Num for an example.
The third argument is the most interesting: it specifies at once how to compute the function value and how to compute the sensitivity with respect to the function parameter.
Note : the type parameters are completely unconstrained.
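To illustrate the unconstrained type parameters, here is a sketch that lifts a function of a pair, a type without a Num instance, by passing zero and plus explicitly. It assumes the public module is named Numeric.AD.DelCont; V2, plusV2, normSq and normSqAt are hypothetical names.

```haskell
import Numeric.AD.DelCont (AD, op1, rad1g) -- module name is an assumption

-- | A 2D point; tuples have no Num instance.
type V2 = (Double, Double)

-- | Component-wise addition, used as "plus" for input adjoints.
plusV2 :: V2 -> V2 -> V2
plusV2 (a, b) (c, d) = (a + c, b + d)

-- | Squared norm, returned together with its pullback.
-- The first argument of op1 (0) is the zero of the output adjoint.
normSq :: AD s V2 V2 -> AD s Double Double
normSq = op1 0 plusV2 $ \(x, y) ->
  (x * x + y * y, \d -> (2 * x * d, 2 * y * d))

-- | rad1g takes the zero of the input adjoint and the one of the
-- output adjoint.
normSqAt :: V2 -> (Double, V2)
normSqAt = rad1g (0, 0) 1 normSq
```

At the point (3, 4) this should return the value 25 together with the adjoint (6, 8).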
op2 :: dc | zero |
-> (da -> da -> da) | plus |
-> (db -> db -> db) | plus |
-> (a -> b -> (c, dc -> da, dc -> db)) | returns : (function result, pullbacks) |
-> AD s a da -> AD s b db -> AD s c dc |
Lift a binary function
See op1 for more information.
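A binary counterpart can be sketched the same way: multiplication lifted with op2 and evaluated with rad2g. The public module name Numeric.AD.DelCont is an assumption; times and timesAt are hypothetical names.

```haskell
import Numeric.AD.DelCont (AD', op2, rad2g) -- module name is an assumption

-- | Multiplication with explicit pullbacks for each operand.
-- Arguments of op2: zero of the output adjoint, then "plus" for the
-- first and second input adjoints.
times :: AD' s Double -> AD' s Double -> AD' s Double
times = op2 0 (+) (+) $ \x y -> (x * y, \dz -> dz * y, \dz -> dz * x)

-- | rad2g takes the zeros of both input adjoints and the one of the
-- output adjoint.
timesAt :: Double -> Double -> (Double, (Double, Double))
timesAt = rad2g 0 0 1 times
```

At (3, 2) the product is 6 and the adjoints are (2, 3), i.e. each operand receives the value of the other.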
Num instances
op1Num :: (Num da, Num db) | |
=> (a -> (b, db -> da)) | returns : (function result, pullback) |
-> AD s a da | |
-> AD s b db |
Helper for constructing unary functions that operate on Num instances (i.e. op1 specialized to Num)
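As a sketch of how op1Num might be used, the following lifts a piecewise function. Comparisons happen on the underlying Double rather than on the AD type, which has no Ord instance. The public module name Numeric.AD.DelCont is an assumption; relu and reluAt are hypothetical names, and taking the derivative at 0 to be 0 is a choice made for this example.

```haskell
import Numeric.AD.DelCont (AD', op1Num, rad1) -- module name is an assumption

-- | Rectified linear unit with a hand-written pullback: the adjoint
-- passes through unchanged for positive inputs and is zeroed otherwise.
relu :: AD' s Double -> AD' s Double
relu = op1Num $ \x -> (max 0 x, \d -> if x > 0 then d else 0)

-- | Value and derivative of 'relu' at a point.
reluAt :: Double -> (Double, Double)
reluAt = rad1 relu
```

Here reluAt 2 should give (2, 1) and reluAt (-1) should give (0, 0).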
op2Num :: (Num da, Num db, Num dc) | |
=> (a -> b -> (c, dc -> da, dc -> db)) | returns : (function result, pullbacks) |
-> AD s a da | |
-> AD s b db | |
-> AD s c dc |
Helper for constructing binary functions that operate on Num instances (i.e. op2 specialized to Num)
Types
data AD0 s a
Mutable references in the continuation monad
Instances
Applicative (AD0 s) | |
Functor (AD0 s) | |
Floating a => Floating (AD s a a) | |
Defined in Numeric.AD.DelCont.Internal sqrt :: AD s a a -> AD s a a # (**) :: AD s a a -> AD s a a -> AD s a a # logBase :: AD s a a -> AD s a a -> AD s a a # asin :: AD s a a -> AD s a a # acos :: AD s a a -> AD s a a # atan :: AD s a a -> AD s a a # sinh :: AD s a a -> AD s a a # cosh :: AD s a a -> AD s a a # tanh :: AD s a a -> AD s a a # asinh :: AD s a a -> AD s a a # acosh :: AD s a a -> AD s a a # atanh :: AD s a a -> AD s a a # log1p :: AD s a a -> AD s a a # expm1 :: AD s a a -> AD s a a # | |
Num a => Num (AD s a a) | The numerical methods of (Num, Fractional, Floating etc.) can be read off their |
Fractional a => Fractional (AD s a a) | |