TMB Documentation  v1.9.11
compression.hpp
1 #ifndef HAVE_COMPRESSION_HPP
2 #define HAVE_COMPRESSION_HPP
3 // Autogenerated - do not edit by hand !
4 #include "global.hpp"
5 #include "graph_transform.hpp" // subset
6 #include "radix.hpp" // first_occurance
7 
8 namespace TMBad {
9 
/** \brief Representation of a period in a sequence.

    Describes `rep` back-to-back copies of the `size` elements that start at
    index `begin`. */
struct period {
  /** \brief Where does the period begin (index of its first element). */
  size_t begin;
  /** \brief Size of the period. */
  size_t size;
  /** \brief Number of consecutive period replicates. */
  size_t rep;
};

/** \brief Stream output for `period`. */
std::ostream &operator<<(std::ostream &os, const period &x);

/** \brief Period analyzer.

    Scans the sequence `x` for places where a block of consecutive elements
    is repeated back-to-back.

    \note Only a reference to `x` is stored; the vector must outlive this
    object. */
template <class T>
struct periodic {
  /** \brief Sequence being analyzed (referenced, not copied). */
  const std::vector<T> &x;
  /** \brief Test periods of size 1 .. max_period_size - 1. */
  size_t max_period_size;
  /** \brief Ignore periods with fewer replicates than this. */
  size_t min_period_rep;
  periodic(const std::vector<T> &x, size_t max_period_size,
           size_t min_period_rep = 2)
      : x(x),
        max_period_size(max_period_size),
        min_period_rep(min_period_rep) {}
  /** \brief Test period one step ahead.

      \return true if `x[start + i] == x[start + i + p]` for all `i < p`,
      i.e. the `p` elements at `start` are repeated immediately after.
      Returns false when `p == 0` or when the comparison would read past the
      end of `x`. */
  bool test_period(size_t start, size_t p) {
    // A zero-length period never advances `start`, which would make
    // numrep_period loop forever.
    if (p == 0) return false;
    // The last index compared is start + (p - 1) + p; it must be in range.
    if (start + (p - 1) + p >= x.size()) return false;
    for (size_t i = 0; i < p; i++) {
      if (x[start + i] != x[start + i + p]) return false;
    }
    return true;
  }
  /** \brief Find number of period replicates.

      \return how many consecutive copies of the length-`p` block starting
      at `start` occur (at least 1: the block itself). */
  size_t numrep_period(size_t start, size_t p) {
    size_t n = 1;
    while (test_period(start, p)) {
      n++;
      start += p;
    }
    return n;
  }
  /** \brief Find the best period starting from a given index.

      Tries candidate sizes `p = 1 .. max_period_size - 1` and keeps the one
      with the most replicates; on ties the earlier (smaller) candidate wins.
      \return `{start, size, rep}`. If no candidate is tested
      (`max_period_size <= 1`), `rep` is 0 and `size` is `size_t(-1)`. */
  period find_best_period(size_t start) {
    size_t p_best = static_cast<size_t>(-1), rep_best = 0;
    for (size_t p = 1; p < max_period_size; p++) {
      size_t rep = numrep_period(start, p);
      if (rep > rep_best) {
        p_best = p;
        rep_best = rep;
        // Heuristic: once p is found to repeat rep times, skip the candidate
        // sizes up to p * rep.
        p = p * rep;
      }
    }
    period ans = {start, p_best, rep_best};
    return ans;
  }
  /** \brief Greedily collect non-overlapping periods, scanning left to right.

      At each position the best period is taken if it has at least
      `min_period_rep` replicates, and scanning resumes right after it.
      Otherwise scanning advances by one element. */
  std::vector<period> find_all() {
    std::vector<period> ans;
    for (size_t i = 0; i < x.size();) {
      period result = find_best_period(i);
      // rep == 0 means no candidate was tested. Accepting it (possible when
      // min_period_rep == 0) would advance i by 0 and never terminate.
      if (result.rep > 0 && result.rep >= min_period_rep) {
        ans.push_back(result);
        i += result.size * result.rep;
      } else {
        i++;
      }
    }
    return ans;
  }
};
114 
/** \brief Read-only column-major view of a raw `nrow` x `ncol` array.

    Element (i, j) is `x[i + j * nrow]`. The view does not own `x`; the
    pointed-to data must outlive the view. */
template <class T>
struct matrix_view {
  const T *x;
  size_t nrow, ncol;
  matrix_view(const T *x, size_t nrow, size_t ncol)
      : x(x), nrow(nrow), ncol(ncol) {}
  /** \brief Element at row `i`, column `j` (no bounds checking). */
  T operator()(size_t i, size_t j) const { return x[i + j * nrow]; }
  size_t rows() const { return nrow; }
  size_t cols() const { return ncol; }
  /** \brief Successive differences along row `i`.

      Each element is converted to `Diff_T` before subtracting, so an unsigned
      `T` can yield negative differences when `Diff_T` is signed.
      \return vector of length `cols() - 1` (empty when `cols() <= 1`) whose
      entry `j - 1` is `(*this)(i, j) - (*this)(i, j - 1)`. */
  template <class Diff_T>
  std::vector<Diff_T> row_diff(size_t i) const {
    size_t nd = (cols() >= 1 ? cols() - 1 : 0);
    std::vector<Diff_T> xd(nd);
    for (size_t j = 1; j < cols(); j++)
      xd[j - 1] = static_cast<Diff_T>((*this)(i, j)) -
                  static_cast<Diff_T>((*this)(i, j - 1));
    return xd;
  }
};
133 
147 std::vector<period> split_period(global *glob, period p,
148  size_t max_period_size);
149 
150 struct compressed_input {
151  typedef std::ptrdiff_t ptrdiff_t;
152 
153  mutable std::vector<ptrdiff_t> increment_pattern;
154  std::vector<Index> which_periodic;
155  std::vector<Index> period_sizes;
156  std::vector<Index> period_offsets;
157  std::vector<ptrdiff_t> period_data;
158 
159  Index n, m;
160  Index nrep;
161  Index np;
162 
163  mutable Index counter;
164  mutable std::vector<Index> inputs;
165  std::vector<Index> input_diff;
166  size_t input_size() const;
167  void update_increment_pattern() const;
168 
169  void increment(Args<> &args) const;
170 
171  void decrement(Args<> &args) const;
172  void forward_init(Args<> &args) const;
173  void reverse_init(Args<> &args);
174  void dependencies_intervals(Args<> &args, std::vector<Index> &lower,
175  std::vector<Index> &upper) const;
176 
177  size_t max_period_size;
178 
179  bool test_period(std::vector<ptrdiff_t> &x, size_t p);
180 
181  size_t find_shortest(std::vector<ptrdiff_t> &x);
182  compressed_input();
183  compressed_input(std::vector<Index> &x, size_t offset, size_t nrow, size_t m,
184  size_t ncol, size_t max_period_size);
185 };
186 
/** \brief Compile-time type equality: `equal` is true iff both arguments
    are the same type. */
template <typename A, typename B>
struct compare_types {
  static const bool equal = false;
};
/** \brief Specialization chosen when both arguments are identical. */
template <typename A>
struct compare_types<A, A> {
  static const bool equal = true;
};
195 
196 void compress(global &glob, size_t max_period_size);
197 struct StackOp : global::SharedDynamicOperator {
198  typedef std::ptrdiff_t ptrdiff_t;
199  global::operation_stack opstack;
200  compressed_input ci;
201  StackOp(global *glob, period p, IndexPair ptr, size_t max_period_size);
203  StackOp(const StackOp &x);
204  void print(global::print_config cfg);
205  Index input_size() const;
206  Index output_size() const;
207  static const bool have_input_size_output_size = true;
214  template <class Type>
215  void forward(ForwardArgs<Type> args) {
216  ci.forward_init(args);
217 
218  size_t opstack_size = opstack.size();
219  for (size_t i = 0; i < ci.nrep; i++) {
220  for (size_t j = 0; j < opstack_size; j++) {
221  opstack[j]->forward_incr(args);
222  }
223  ci.increment(args);
224  }
225  if (compare_types<Type, Replay>::equal) {
226  compress(*get_glob(), ci.max_period_size);
227  }
228  }
229  void forward(ForwardArgs<Writer> &args);
236  template <class Type>
237  void reverse(ReverseArgs<Type> args) {
238  ci.reverse_init(args);
239  size_t opstack_size = opstack.size();
240  for (size_t i = 0; i < ci.nrep; i++) {
241  ci.decrement(args);
242 
243  for (size_t j = opstack_size; j > 0;) {
244  j--;
245  opstack[j]->reverse_decr(args);
246  }
247  }
248  if (compare_types<Type, Replay>::equal) {
249  compress(*get_glob(), ci.max_period_size);
250  }
251  }
252  void reverse(ReverseArgs<Writer> &args);
257  void dependencies(Args<> args, Dependencies &dep) const;
259  static const bool have_dependencies = true;
261  static const bool implicit_dependencies = true;
263  static const bool allow_remap = false;
264  const char *op_name();
265 };
266 
/** \brief Remove every element equal to `elt` from `v`, keeping the
    relative order of the remaining elements (erase-remove idiom). */
template <class T>
void trim(std::vector<T> &v, const T &elt) {
  typename std::vector<T>::iterator keep_end =
      std::remove(v.begin(), v.end(), elt);
  v.erase(keep_end, v.end());
}
271 
272 template <class T>
273 struct toposort_remap {
274  std::vector<T> &remap;
275  T i;
276  toposort_remap(std::vector<T> &remap, T i) : remap(remap), i(i) {}
277  void operator()(Index k) {
278  if (remap[k] >= remap[i]) {
279  remap[i] = i;
280  }
281  }
282 };
283 
308 void reorder_sub_expressions(global &glob);
309 
310 template <class T>
311 struct temporaries_remap {
312  std::vector<T> &remap;
313  T i;
314  temporaries_remap(std::vector<T> &remap, T i) : remap(remap), i(i) {}
315  void operator()(Index k) {
316  if (remap[k] == T(-1)) {
317  if (i > k + 1) remap[k] = i;
318  return;
319  }
320 
321  remap[k] = k;
322  }
323 };
324 
332 void reorder_temporaries(global &glob);
333 
334 template <class T>
335 struct dfs_add_to_stack {
336  std::vector<T> &stack;
337  std::vector<bool> &visited;
338  std::vector<T> &v2o;
339  dfs_add_to_stack(std::vector<T> &stack, std::vector<bool> &visited,
340  std::vector<T> &v2o)
341  : stack(stack), visited(visited), v2o(v2o) {}
342  void operator()(T var) {
343  Index op = v2o[var];
344  if (!visited[op]) {
345  stack.push_back(op);
346  visited[op] = true;
347  }
348  }
349 };
350 
354 void reorder_depth_first(global &glob);
355 
356 void compress(global &glob, size_t max_period_size = 1024);
357 
358 } // namespace TMBad
359 #endif // HAVE_COMPRESSION_HPP
Automatic differentiation library designed for TMB.
Definition: TMB.hpp:157
void reorder_temporaries(global &glob)
Re-order computational graph to make it more compressible.
Definition: TMBad.cpp:567
size_t rep
Number of consecutive period replicates.
Definition: compression.hpp:17
size_t size
Size of the period.
Definition: compression.hpp:15
global * get_glob()
Get pointer to current global AD context (or NULL if no context is active).
Definition: TMBad.cpp:690
Access input/output values and derivatives during a reverse pass. Write access granted for the input ...
Definition: global.hpp:311
size_t min_period_rep
Ignore periods with fewer replicates than this.
Definition: compression.hpp:46
void reorder_sub_expressions(global &glob)
Re-order computational graph to make it more compressible.
Definition: TMBad.cpp:537
void reorder_depth_first(global &glob)
Depth-first reordering of computational graph.
Definition: TMBad.cpp:595
Access input/output values during a forward pass. Write access granted for the output value only...
Definition: global.hpp:279
period find_best_period(size_t start)
Find the best period (highest compression degree) starting from a given index.
Definition: compression.hpp:87
Period analyzer.
Definition: compression.hpp:41
Configuration of print method.
Definition: global.hpp:1479
Representation of a period in a sequence.
Definition: compression.hpp:11
size_t begin
Where does the period begin.
Definition: compression.hpp:13
Struct defining the main AD context.
Definition: global.hpp:797
std::vector< period > split_period(global *glob, period p, size_t max_period_size)
Helper.
Definition: TMBad.cpp:190
bool test_period(size_t start, size_t p)
Test period one step ahead.
Definition: compression.hpp:56
size_t numrep_period(size_t start, size_t p)
Find the number of period replicates.
Definition: compression.hpp:68
size_t max_period_size
Test periods of size strictly less than this.
Definition: compression.hpp:44
License: GPL v2