15 Atomic functions

Custom functions and derivatives can be added to the TMB library. This may be necessary for the following reasons:

  • Adaptive (e.g. iterative) algorithms cannot be represented by a fixed computational graph and thus cannot be directly differentiated using TMB. Algorithms that use parameter-dependent if-else branching are examples of such functions.
  • Some functions involve so many floating-point operations that it is infeasible to expand the computational graph. Memory usage may be greatly reduced in such cases by collapsing the computational graph to a single node with multiple inputs and outputs.

15.1 Reverse mode differentiation

TMB uses CppAD as its engine for reverse mode derivatives. In order to add a new primitive function

\[f: R^n \rightarrow R^m\]

we must inform CppAD how to calculate derivatives of this function in reverse mode. That is, for any range space vector \(w \in R^m\) we must calculate the gradient of the function \(R^n \rightarrow R\) given by

\[ x \mapsto \text{sum}( f(x) \odot w ) \]

where ‘\(\odot\)’ is pointwise multiplication.
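
Componentwise, this gradient is the vector-Jacobian product

\[ \frac{\partial}{\partial x_j} \, \text{sum}\left( f(x) \odot w \right) = \sum_{i=1}^{m} w_i \, \frac{\partial f_i}{\partial x_j}(x), \qquad j = 1, \ldots, n \]

so in the scalar case \(n = m = 1\) it reduces to the familiar chain rule factor \(f'(x) \, w\).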

15.2 Example: Adding new primitive function with known derivatives

As an example consider the Lambert W function defined implicitly by

\[y = W(y e^y)\]

Here, we only consider \(W\) as defined on the positive reals. It follows, by differentiating the above identity, that

\[ W'(x) = \frac{1}{ \exp\left(W(x)\right) \left(1 + W(x)\right) } \]
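
To verify this, substitute \(x = y e^y\) so that the identity reads \(x = W(x) e^{W(x)}\), and differentiate both sides with respect to \(x\):

\[ 1 = W'(x) \, e^{W(x)} + W(x) \, e^{W(x)} \, W'(x) = W'(x) \, e^{W(x)} \left( 1 + W(x) \right) \]

Solving for \(W'(x)\) gives the stated formula.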

When coding reverse-mode derivatives we can assume that the function value \(W(x)\) has already been computed during a forward pass. For efficiency reasons we should use this intermediate calculation rather than re-calculating \(W(x)\) in the reverse pass.

We’ll assume that a plain C++ function (taking double types as input/output) is available to calculate \(W(x)\). It doesn’t matter whether you have the source code of an implementation or just the header with linkage to an external library:

double LambertW(double x);

The macro TMB_ATOMIC_VECTOR_FUNCTION() is used to declare our new primitive Lambert \(W\) function:

TMB_ATOMIC_VECTOR_FUNCTION(
    // ATOMIC_NAME
    LambertW
    ,
    // OUTPUT_DIM
    1,
    // ATOMIC_DOUBLE
    ty[0] = LambertW(tx[0]); // Call the 'double' version
    ,
    // ATOMIC_REVERSE
    Type W  = ty[0];                    // Function value from forward pass
    Type DW = 1. / (exp(W) * (1. + W)); // Derivative
    px[0] = DW * py[0];                 // Reverse mode chain rule
)

Let’s explain in detail what is going on. The macro takes four arguments:

  1. ATOMIC_NAME: The name of the new primitive function, which takes a CppAD::vector as input and returns a CppAD::vector as output.
  2. OUTPUT_DIM: The dimension of the CppAD::vector returned by the function.
  3. ATOMIC_DOUBLE: Specifies how to evaluate the primitive function for the ordinary double type. tx denotes the input vector and ty the output vector of the function \(f: R^n \rightarrow R^m\). In this case both have dimension one.
  4. ATOMIC_REVERSE: Specifies how to calculate the reverse mode derivatives for a general Type. Again tx and ty denote the function input and output, but now ty has already been computed during the forward pass and is available as an intermediate value. The vectors px and py denote the partial derivatives of the end result with respect to \(x\) and \(y\) respectively. py is given, and we must calculate px using the chain rule. This first-order derivative rule is automatically expanded to the higher orders required by TMB's random effects calculations. A hypothetical multi-input example is sketched below.
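
To see how the vectors generalize beyond dimension one, here is a hypothetical two-input, one-output atomic; the name myprod and the toy function \(f(x_1, x_2) = x_1 x_2\) are invented for illustration only. The reverse rule fills one entry of px per input:

TMB_ATOMIC_VECTOR_FUNCTION(
    // ATOMIC_NAME
    myprod
    ,
    // OUTPUT_DIM
    1,
    // ATOMIC_DOUBLE
    ty[0] = tx[0] * tx[1]; // f(x1, x2) = x1 * x2
    ,
    // ATOMIC_REVERSE
    px[0] = tx[1] * py[0]; // df/dx1 = x2, times incoming weight
    px[1] = tx[0] * py[0]; // df/dx2 = x1, times incoming weight
)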

To make the function work like other TMB functions, it is convenient to define a scalar version and a vectorized version that call the atomic function:

// Scalar version
template<class Type>
Type LambertW(Type x){
  CppAD::vector<Type> tx(1);
  tx[0] = x;
  return LambertW(tx)[0];
}

// Vectorized version
VECTORIZE1_t(LambertW)
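
The vectorized version applies the scalar version elementwise, so LambertW can be called directly on a parameter vector, as in the complete example below.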

15.2.1 Testing the primitive function

Here is a complete example using Newton’s method to calculate the Lambert \(W\) function (there are more sophisticated algorithms such as the one by Fukushima (2013), but that doesn’t matter for this example):


#include <TMB.hpp>

// Double version of Lambert W function
double LambertW(double x) {
  double logx = log(x);
  double y = (logx > 0 ? logx : 0); // Initial guess: W(x) ~ log(x) for large x
  int niter = 100, i = 0;
  for (; i < niter; i++) {
    // Converged when log(y) + y = log(x), i.e. y * exp(y) = x
    if (fabs(logx - log(y) - y) < 1e-9) break;
    // Newton-type update for y = x * exp(-y); the denominator 1 + y
    // approximates the derivative 1 + x * exp(-y), exactly at the solution
    y -= (y - exp(logx - y)) / (1 + y);
  }
  if (i == niter) Rf_warning("W: failed convergence");
  return y;
}

TMB_ATOMIC_VECTOR_FUNCTION(
    // ATOMIC_NAME
    LambertW
    ,
    // OUTPUT_DIM
    1,
    // ATOMIC_DOUBLE
    ty[0] = LambertW(tx[0]); // Call the 'double' version
    ,
    // ATOMIC_REVERSE
    Type W  = ty[0];                    // Function value from forward pass
    Type DW = 1. / (exp(W) * (1. + W)); // Derivative
    px[0] = DW * py[0];                 // Reverse mode chain rule
)

// Scalar version
template<class Type>
Type LambertW(Type x){
  CppAD::vector<Type> tx(1);
  tx[0] = x;
  return LambertW(tx)[0];
}

// Vectorized version
VECTORIZE1_t(LambertW)

template<class Type>
Type objective_function<Type>::operator() ()
{
  PARAMETER_VECTOR(x);
  Type f = LambertW(x).sum();
  return f;
}

And from R:

compile("lambert.cpp")
dyn.load(dynlib("lambert"))

15.2.1.1 Checking function value and derivatives

Check the definition of the function. Since \(W(y e^y) = y\), evaluating at \(7 e^7\) should return 7:

obj <- MakeADFun(data=list(), parameters=list(x=1), DLL="lambert")
obj$fn(7 * exp(7))
## [1] 7

Check derivatives using the numDeriv package:

numDeriv::grad(obj$fn, 7)
## [1] 0.08626538
obj$gr(7)
##            [,1]
## [1,] 0.08626538
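
This agrees with the analytic formula derived earlier: with \(W(7) \approx 1.5243\) we get \(W'(7) = 1 / \left( e^{W(7)} (1 + W(7)) \right) \approx 0.0863\).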

Also try second-order derivatives (a small numerical discrepancy between the finite-difference result and the AD result is expected):

numDeriv::hessian(obj$fn, 7)
##             [,1]
## [1,] -0.01038959
obj$he(7)
##             [,1]
## [1,] -0.01038969

15.3 Other approaches

For the Lambert \(W\) function we know how to calculate the derivatives. In other cases the derivatives may be impossible (or very cumbersome) to write down. If you're in this situation, you may want to try using forward mode AD to help define an atomic function. A fully worked-out example is available here: adaptive_integration.cpp. Derivatives are calculated automatically, and if-else branching is allowed. The main downside of this approach is that it is limited to functions with very few inputs.

Checkpointing is another useful technique, demonstrated in the example register_atomic.cpp. It does not work for adaptive algorithms but is otherwise automatic. It is useful for reducing AD memory usage in cases where the same sequence of operations is applied many times.
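
As a rough sketch of that pattern (the function blockfun and its body are hypothetical placeholders; consult register_atomic.cpp for the authoritative usage of the REGISTER_ATOMIC macro):

// Some fixed sequence of operations that the model applies many times
template<class Type>
vector<Type> blockfun(vector<Type> x) {
  vector<Type> y(1);
  y(0) = (x * x).sum(); // placeholder computation
  return y;
}
// Collapse the operation sequence of 'blockfun' into a single atomic node
REGISTER_ATOMIC(blockfun)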