15 Atomic functions
Custom functions and derivatives can be added to the TMB library. This may be necessary for the following reasons:
- Adaptive (e.g. iterative) algorithms cannot be represented by a fixed computational graph and thus cannot be directly differentiated using TMB. Algorithms that use parameter-dependent if-else branching are examples of such functions.
- Some functions have so many floating point operations that it is infeasible to expand the computational graph. Memory usage may be greatly reduced in such cases by collapsing the computational graph to a single node with multiple inputs and outputs.
15.1 Reverse mode differentiation
TMB uses CppAD as its engine for reverse mode derivatives. In order to add a new primitive function
\[f: R^n \rightarrow R^m\]
we must inform CppAD how to calculate derivatives of this function in reverse mode. That is, for any range space vector \(w \in R^m\) we must calculate the gradient of the function \(R^n \rightarrow R\) given by
\[ x \rightarrow \text{sum}( f(x) \odot w ) \]
where ‘\(\odot\)’ is pointwise multiplication.
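Componentwise, the required gradient is
\[ \frac{\partial}{\partial x_j} \sum_{i=1}^{m} w_i f_i(x) = \sum_{i=1}^{m} w_i \frac{\partial f_i}{\partial x_j}(x), \qquad j = 1, \dots, n. \]
In the macro introduced below, the weight vector \(w\) is called py and the gradient to be filled in is called px.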
15.2 Example: Adding new primitive function with known derivatives
As an example consider the Lambert W function defined implicitly by
\[y = W(y e^y)\]
Here, we only consider \(W\) as defined on the positive reals. It follows, by differentiating the above identity, that
\[ W'(x) = \frac{1}{ \exp\left(W(x)\right) \left(1 + W(x)\right) } \]
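To verify this, write the defining relation as \(x = W(x)e^{W(x)}\) and differentiate both sides with respect to \(x\):
\[ 1 = W'(x)\,e^{W(x)} + W(x)\,e^{W(x)}\,W'(x) = W'(x)\,e^{W(x)}\left(1 + W(x)\right), \]
and solving for \(W'(x)\) gives the formula above.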
When coding reverse-mode derivatives we can assume that the function value \(W(x)\) has already been computed during a forward pass. For efficiency reasons we should use this intermediate calculation rather than re-calculating \(W(x)\) in the reverse pass.
We’ll assume that a plain C++ function (taking double types as input/output) is available to calculate \(W(x)\). It doesn’t matter whether you have the source code of an implementation or just the header with linkage to an external library:
double LambertW(double x);
The macro TMB_ATOMIC_VECTOR_FUNCTION() is used to declare our new primitive Lambert \(W\) function:
TMB_ATOMIC_VECTOR_FUNCTION(
    // ATOMIC_NAME
    LambertW
    ,
    // OUTPUT_DIM
    1,
    // ATOMIC_DOUBLE
    ty[0] = LambertW(tx[0]); // Call the 'double' version
    ,
    // ATOMIC_REVERSE
    Type W  = ty[0];                    // Function value from forward pass
    Type DW = 1. / (exp(W) * (1. + W)); // Derivative
    px[0] = DW * py[0];                 // Reverse mode chain rule
)
Let’s explain in detail what is going on. The macro takes four arguments:
ATOMIC_NAME
: Name of the new primitive function taking CppAD::vector as input and output.

OUTPUT_DIM
: Dimension of the CppAD::vector which is the function output.

ATOMIC_DOUBLE
: Specifies how to evaluate the primitive function for the ordinary double type. tx denotes the input vector and ty the output vector of the function \(f: R^n \rightarrow R^m\). In this case both have dimension one.

ATOMIC_REVERSE
: How to calculate the reverse mode derivatives for a general Type. Again tx and ty denote function input and output, but now ty has already been computed and is available as an intermediate value. The vectors px and py denote partial derivatives of the end result with respect to \(x\) and \(y\) respectively. py is given and we must calculate px using the chain rule (see the two-input sketch just below). This first order derivative rule is automatically expanded up to the higher orders required by TMB's random effects calculations.
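Here is a minimal sketch of an atomic function with two inputs, \(f(x_0, x_1) = x_0 x_1\), illustrating how px gets one entry per input (atomic_prod is an illustrative name, not part of TMB; the input dimension is inferred from the tx vector supplied by the caller):
TMB_ATOMIC_VECTOR_FUNCTION(
    // ATOMIC_NAME
    atomic_prod
    ,
    // OUTPUT_DIM
    1,
    // ATOMIC_DOUBLE
    ty[0] = tx[0] * tx[1];
    ,
    // ATOMIC_REVERSE
    // One partial per input: px[j] = py[0] * df/dx_j
    px[0] = py[0] * tx[1]; // df/dx0 = x1
    px[1] = py[0] * tx[0]; // df/dx1 = x0
)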
To make the function work like other TMB functions it is convenient to define a scalar and a vectorized version, both calling the atomic function:
// Scalar version
template<class Type>
Type LambertW(Type x){
  CppAD::vector<Type> tx(1);
  tx[0] = x;
  return LambertW(tx)[0];
}

// Vectorized version
VECTORIZE1_t(LambertW)
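Both versions can then be used in template code as usual (a hypothetical snippet; u is assumed to be a parameter vector):
Type w1 = LambertW(u[0]);      // scalar version
vector<Type> w2 = LambertW(u); // vectorized version, applied elementwise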
15.2.1 Testing the primitive function
Here is a complete example using Newton’s method to calculate the Lambert \(W\) function (there are more sophisticated algorithms such as the one by Fukushima (2013), but that doesn’t matter for this example):
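The iteration is Newton's method applied to \(h(y) = ye^y - x = 0\) (checked on log scale as \(\log x - \log y - y = 0\)), with update
\[ y \leftarrow y - \frac{ye^y - x}{e^y(1+y)} = y - \frac{y - xe^{-y}}{1+y}, \]
which is exactly the step in the loop below.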
#include <TMB.hpp>

// Double version of Lambert W function
double LambertW(double x) {
  double logx = log(x);
  double y = (logx > 0 ? logx : 0);
  int niter = 100, i = 0;
  for (; i < niter; i++) {
    if (fabs(logx - log(y) - y) < 1e-9) break;
    y -= (y - exp(logx - y)) / (1 + y);
  }
  if (i == niter) Rf_warning("W: failed convergence");
  return y;
}

TMB_ATOMIC_VECTOR_FUNCTION(
    // ATOMIC_NAME
    LambertW
    ,
    // OUTPUT_DIM
    1,
    // ATOMIC_DOUBLE
    ty[0] = LambertW(tx[0]); // Call the 'double' version
    ,
    // ATOMIC_REVERSE
    Type W  = ty[0];                    // Function value from forward pass
    Type DW = 1. / (exp(W) * (1. + W)); // Derivative
    px[0] = DW * py[0];                 // Reverse mode chain rule
)

// Scalar version
template<class Type>
Type LambertW(Type x){
  CppAD::vector<Type> tx(1);
  tx[0] = x;
  return LambertW(tx)[0];
}

// Vectorized version
VECTORIZE1_t(LambertW)

template<class Type>
Type objective_function<Type>::operator() ()
{
  PARAMETER_VECTOR(x);
  Type f = LambertW(x).sum();
  return f;
}
And from R
compile("lambert.cpp")
dyn.load(dynlib("lambert"))
15.2.1.1 Checking function value and derivatives
Check definition of the function:
obj <- MakeADFun(data=list(), parameters=list(x=1), DLL="lambert")
obj$fn(7 * exp(7))
## [1] 7
As expected from the defining identity, \(W(7e^7) = 7\).
Check derivatives using the numDeriv package:
numDeriv::grad(obj$fn, 7)
## [1] 0.08626538
obj$gr(7)
##            [,1]
## [1,] 0.08626538
Also try second order derivatives:
numDeriv::hessian(obj$fn, 7)
##             [,1]
## [1,] -0.01038959
obj$he(7)
##             [,1]
## [1,] -0.01038969
15.3 Other approaches
For the Lambert \(W\) function we know how to calculate the derivatives. There are cases for which the derivatives are impossible (or difficult) to write down. If you're in this situation you may want to use forward mode AD to help define an atomic function. A fully worked-out example is available here: adaptive_integration.cpp. Derivatives are calculated automatically and if-else branching is allowed. The main downside of this approach is that it is limited to functions with very few inputs.
Checkpointing is another useful technique. It is demonstrated in the example register_atomic.cpp. It does not work for adaptive algorithms but is otherwise automatic, and it reduces AD memory usage in cases where the same sequence of operations is applied many times; a sketch follows.
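A minimal sketch of the checkpointing idea, assuming the REGISTER_ATOMIC macro demonstrated in register_atomic.cpp (the function expensive below is a hypothetical stand-in; consult the example for authoritative usage):
#include <TMB.hpp>

// Hypothetical stand-in for a long, fixed sequence of operations
template<class Type>
vector<Type> expensive(vector<Type> x) {
  vector<Type> y(1);
  y[0] = exp(x[0]) * sin(x[1]);
  return y;
}
REGISTER_ATOMIC(expensive)

template<class Type>
Type objective_function<Type>::operator() ()
{
  PARAMETER_VECTOR(u); // assumed of length >= 2
  Type nll = 0;
  for (int i = 0; i < 10; i++) {
    // Each call adds a single node to the tape instead of a full
    // copy of the operation sequence inside 'expensive'
    nll += expensive(u)[0];
  }
  return nll;
}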