TMB Documentation  v1.9.11
checkpoint.hpp
1 #ifndef HAVE_CHECKPOINT_HPP
2 #define HAVE_CHECKPOINT_HPP
3 // Autogenerated - do not edit by hand !
4 #include <memory>
5 #include "global.hpp"
6 #include "vectorize.hpp"
7 
8 namespace TMBad {
9 
11 template <class ADFun, bool packed_ = false>
12 struct standard_derivative_table : std::vector<ADFun> {
13  static const bool packed = packed_;
15  void requireOrder(size_t n) {
16  while ((*this).size() <= n) {
17  (*this).push_back((*this).back().WgtJacFun());
18  }
19  }
21  void retape(ForwardArgs<Scalar> &args) {}
23  standard_derivative_table(const ADFun &F) : std::vector<ADFun>(1, F) {}
24 };
25 
32  static const bool packed = false;
33  std::vector<Scalar> x_prev;
34  bool operator()(const std::vector<Scalar> &x);
35 };
36 
38 template <class Functor, class ADFun, class Test = ParametersChanged,
39  bool packed_ = false>
41  Functor F;
42  Test test;
46  size_t n = (*this)[0].Domain();
47  std::vector<Scalar> x = args.x_segment(0, n);
48  bool change = test(x);
49  if (change) {
50  (*this).resize(1);
51  (*this)[0] = ADFun(F, x);
52  }
53  }
56  template <class V>
57  retaping_derivative_table(const Functor &F, const V &x, Test test = Test())
58  : standard_derivative_table<ADFun, packed_>(ADFun(F, x)),
59  F(F),
60  test(test) {}
61 };
62 
94 template <class T>
96  typedef std::shared_ptr<T> Base;
97  Base sp;
98  std::shared_ptr<std::vector<std::weak_ptr<T> > > weak_refs;
99 
100  omp_shared_ptr(const Base &x)
101  : sp(x), weak_refs(std::make_shared<std::vector<std::weak_ptr<T> > >()) {
102  (*weak_refs).resize(TMBAD_MAX_NUM_THREADS);
103  (*weak_refs)[TMBAD_THREAD_NUM] = x;
104  }
105  omp_shared_ptr(const omp_shared_ptr &other) : weak_refs(other.weak_refs) {
106  if ((*weak_refs)[TMBAD_THREAD_NUM].expired()) {
107  sp = std::make_shared<T>(*other);
108 
109  (*weak_refs)[TMBAD_THREAD_NUM] = sp;
110  } else {
111  sp = (*weak_refs)[TMBAD_THREAD_NUM].lock();
112  }
113  }
114  omp_shared_ptr() {}
115  T &operator*() const { return *sp; }
116  T *operator->() const { return sp.get(); }
117  explicit operator bool() const { return (bool)sp; }
118 };
119 
166 template <class DerivativeTable>
167 struct AtomOp : global::DynamicOperator<-1, -1> {
168  static const bool have_input_size_output_size = true;
169  static const bool add_forward_replay_copy = true;
170 
171  TMBAD_SHARED_PTR<DerivativeTable> dtab;
172 
173  int order;
174 
175  template <class T1>
176  AtomOp(const T1 &F) : dtab(std::make_shared<DerivativeTable>(F)), order(0) {}
177  template <class T1, class T2>
178  AtomOp(const T1 &F, const T2 &x)
179  : dtab(std::make_shared<DerivativeTable>(F, x)), order(0) {}
180  template <class T1, class T2, class T3>
181  AtomOp(const T1 &F, const T2 &x, const T3 &t)
182  : dtab(std::make_shared<DerivativeTable>(F, x, t)), order(0) {}
183 
184  Index input_size() const { return (*dtab)[order].Domain(); }
185  Index output_size() const { return (*dtab)[order].Range(); }
186 
187  void forward(ForwardArgs<Scalar> &args) {
188  (*dtab).retape(args);
189 
190  (*dtab).requireOrder(order);
191 
192  size_t n = input_size();
193  size_t m = output_size();
194 
195  auto x = args.x_segment(0, n);
196 
197  args.y_segment(0, m) = (*dtab)[order](x);
198  }
199 
200  void reverse(ReverseArgs<Scalar> &args) {
201  size_t n = input_size();
202  size_t m = output_size();
203 
204  auto x = args.x_segment(0, n);
205  auto w = args.dy_segment(0, m);
206 
207  args.dx_segment(0, n) += (*dtab)[order].Jacobian(x, w);
208  }
209 
210  void reverse(ReverseArgs<global::Replay> &args) {
211  size_t n = input_size();
212  size_t m = output_size();
213 
214  std::vector<global::Replay> x = args.x_segment(0, n);
215  if (DerivativeTable::packed) x = repack(x);
216  std::vector<global::Replay> w = args.dy_segment(0, m);
217  std::vector<global::Replay> xw;
218  xw.insert(xw.end(), x.begin(), x.end());
219  xw.insert(xw.end(), w.begin(), w.end());
220 
221  (*dtab).requireOrder(order + 1);
222  AtomOp cpy(*this);
223  cpy.order++;
224  args.dx_segment(0, n) += global::Complete<AtomOp>(cpy)(xw);
225  }
226 
227  template <class T>
228  void forward(ForwardArgs<T> &args) {
229  TMBAD_ASSERT(false);
230  }
231  void reverse(ReverseArgs<Writer> &args) { TMBAD_ASSERT(false); }
232 
233  const char *op_name() { return "AtomOp"; }
234 
235  void print(global::print_config cfg) {
236  Rcout << cfg.prefix;
237  Rcout << "order=" << order << " ";
238  Rcout << "(*dtab).size()=" << (*dtab).size() << " ";
239  Rcout << "dtab=" << &(*dtab) << "\n";
240  (*dtab)[order].print(cfg);
241  }
242 };
243 
252 template <class Functor>
253 struct PackWrap {
254  Functor F;
255  PackWrap(const Functor &F) : F(F) {}
258  template <class T>
259  std::vector<T> operator()(const std::vector<T> &xp) {
260  Index K = ScalarPack<SegmentRef>::size;
261  size_t n = xp.size() / K;
262  TMBAD_ASSERT2(n * K == xp.size(), "Invalid packed arguments");
263  std::vector<ad_segment> x(n);
264  for (size_t i = 0; i < n; i++) x[i] = unpack(xp, i);
265  ad_segment y = F(x);
266  ad_segment yp = pack(y);
267  std::vector<T> ans = concat(std::vector<ad_segment>(1, yp));
268  return ans;
269  }
272  bool operator()(const std::vector<Scalar> &xp) {
273  Index K = ScalarPack<SegmentRef>::size;
274  size_t n = xp.size() / K;
275  TMBAD_ASSERT2(n * K == xp.size(), "Invalid packed arguments");
276  std::vector<Scalar *> x(n);
277  for (size_t i = 0; i < n; i++) x[i] = unpack(xp, i);
278  return F(x);
279  }
280 };
281 
282 } // namespace TMBad
283 #endif // HAVE_CHECKPOINT_HPP
Automatic differentiation library designed for TMB.
Definition: TMB.hpp:157
segment_ref< ReverseArgs, dx_write > dx_segment(Index from, Index size)
segment version
Definition: global.hpp:344
Vector class used by TMB.
Definition: vector.hpp:17
segment_ref< ReverseArgs, x_read > x_segment(Index from, Index size)
segment version
Definition: global.hpp:336
segment_ref< ForwardArgs, x_read > x_segment(Index from, Index size)
segment version
Definition: global.hpp:293
bool operator()(const std::vector< Scalar > &xp)
Transformed 'tester' assuming the original maps std::vector<Scalar*> to bool
Definition: checkpoint.hpp:272
void retape(ForwardArgs< Scalar > &args)
Retape the zero-order derivative (when the tester reports changed parameters) and remove all higher orders from the table.
Definition: checkpoint.hpp:45
segment_ref< ForwardArgs, y_write > y_segment(Index from, Index size)
segment version
Definition: global.hpp:297
Access input/output values and derivatives during a reverse pass. Write access granted for the input ...
Definition: global.hpp:311
Contiguous set of variables on the current tape.
Definition: global.hpp:2780
Transform a functor to have packed input/output.
Definition: checkpoint.hpp:253
Automatic differentiation function object.
Definition: TMBad.hpp:117
vector< Type > operator*(matrix< Type > A, vector< Type > x)
Definition: convenience.hpp:42
Default tester for retaping_derivative_table.
Definition: checkpoint.hpp:31
Configuration of print method.
Definition: global.hpp:1479
segment_ref< ReverseArgs, dy_read > dy_segment(Index from, Index size)
segment version
Definition: global.hpp:348
void requireOrder(size_t n)
Add derivatives up to this order.
Definition: checkpoint.hpp:15
retaping_derivative_table(const Functor &F, const V &x, Test test=Test())
Set the zero-order functor used to retape this derivative table.
Definition: checkpoint.hpp:57
Manage shared operator data across multiple threads.
Definition: checkpoint.hpp:95
ad_segment unpack(const ad_segment &x)
Unpack consecutive values on the tape.
Definition: TMBad.cpp:4653
ad_segment pack(const ad_segment &x)
Pack consecutive values on the tape.
Definition: TMBad.cpp:4648
std::vector< size_t > order(std::vector< T > x)
Get permutation that sorts a vector.
Operator auto-completion.
Definition: global.hpp:2129
Operator that requires dynamic allocation. Compile time known input/output size.
Definition: global.hpp:1590
Fixed derivative table used by AtomOp
Definition: checkpoint.hpp:12
std::vector< T > operator()(const std::vector< T > &xp)
Transformed functor assuming the original maps std::vector<ad_segment> to ad_segment
Definition: checkpoint.hpp:259
void retape(ForwardArgs< Scalar > &args)
Retaping this derivative table has no effect.
Definition: checkpoint.hpp:21
standard_derivative_table(const ADFun &F)
Set the zero-order function of this derivative table.
Definition: checkpoint.hpp:23
Generic checkpoint operator.
Definition: checkpoint.hpp:167
Adaptive derivative table used by AtomOp
Definition: checkpoint.hpp:40
License: GPL v2