SpJacFun_config::SpJacFun_config() : compress(false), index_remap(true) {}

vmatrix matmul(const vmatrix &x, const vmatrix &y) {
  vmatrix z(x.rows(), y.cols());
  Map<vmatrix> zm(&z(0), z.rows(), z.cols());
  matmul<false, false, false, false>(x, y, zm);
  return z;
}

dmatrix matmul(const dmatrix &x, const dmatrix &y) { return x * y; }
#include "checkpoint.hpp"

bool ParametersChanged::operator()(const std::vector<Scalar> &x) {
  bool change = (x != x_prev);
  x_prev = x;
  return change;
}
#include "code_generator.hpp"

void searchReplace(std::string &str, const std::string &oldStr,
                   const std::string &newStr) {
  std::string::size_type pos = 0u;
  while ((pos = str.find(oldStr, pos)) != std::string::npos) {
    str.replace(pos, oldStr.length(), newStr);
    pos += newStr.length();
  }
}

std::string code_config::float_ptr() { return float_str + (gpu ? "**" : "*"); }

std::string code_config::void_str() {
  return (gpu ? "__device__ void" : "extern \"C\" void");
}

void code_config::init_code() {
  if (gpu) *cout << indent << "int idx = threadIdx.x;" << std::endl;
}

void code_config::write_header_comment() {
  if (header_comment.length() > 0) *cout << header_comment << std::endl;
}

code_config::code_config()
    : header_comment("// Autogenerated - do not edit by hand !"),
      float_str(xstringify(TMBAD_SCALAR_TYPE)) {}
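// Usage sketch (editor's illustration, not part of the original source):
// searchReplace() rewrites every occurrence of a substring in place.
//
//   std::string s = "y[0]=x[0];y[1]=x[1];";
//   searchReplace(s, "]", "][idx]");
//   // s == "y[0][idx]=x[0][idx];y[1][idx]=x[1][idx];"
//
// This is the same "][idx]" rewrite that write_common() below applies when
// emitting GPU code.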
void write_common(std::ostringstream &buffer, code_config cfg, size_t node) {
  std::ostream &cout = *cfg.cout;
  std::string indent = cfg.indent;
  if (cfg.asm_comments)
    cout << indent << "asm(\"// Node: " << node << "\");" << endl;
  bool empty_buffer = (buffer.tellp() == 0);
  if (!empty_buffer) {
    std::string str = buffer.str();
    if (cfg.gpu) {
      std::string pattern = "]";
      std::string replace = "][idx]";
      searchReplace(str, pattern, replace);
    }
    searchReplace(str, ";v", "; v");
    searchReplace(str, ";d", "; d");
    cout << indent << str << endl;
  }
}
void write_forward(global &glob, code_config cfg) {
  std::ostream &cout = *cfg.cout;
  cfg.write_header_comment();
  cout << cfg.void_str() << " forward(" << cfg.float_ptr() << " v) {" << endl;
  for (size_t i = 0; i < glob.opstack.size(); i++) {
    std::ostringstream buffer;
    Writer::cout = &buffer;
    glob.opstack[i]->forward(args);
    write_common(buffer, cfg, i);
  }
  cout << "}" << endl;
}

void write_reverse(global &glob, code_config cfg) {
  std::ostream &cout = *cfg.cout;
  cfg.write_header_comment();
  cout << cfg.void_str() << " reverse(" << cfg.float_ptr() << " v, "
       << cfg.float_ptr() << " d) {" << endl;
  for (size_t i = glob.opstack.size(); i > 0;) {
    i--;
    std::ostringstream buffer;
    Writer::cout = &buffer;
    glob.opstack[i]->reverse(args);
    write_common(buffer, cfg, i);
  }
  cout << "}" << endl;
}

void write_all(global glob, code_config cfg) {
  std::ostream &cout = *cfg.cout;
  cout << "#include \"global.hpp\"" << endl;
  cout << "#include \"ad_blas.hpp\"" << endl;
  write_forward(glob, cfg);
  write_reverse(glob, cfg);
  cout << "int main() {}" << endl;
}
#include "compile.hpp"

void compile(global &glob, code_config cfg) {
  cfg.asm_comments = false;
  std::ofstream file;
  file.open("tmp.cpp");
  cfg.cout = &file;
  *cfg.cout << "#include <cmath>" << std::endl;
  *cfg.cout
      << "template<class T>T sign(const T &x) { return (x > 0) - (x < 0); }"
      << std::endl;
  write_forward(glob, cfg);
  write_reverse(glob, cfg);
  file.close();
  int out = system("g++ -O3 -g tmp.cpp -o tmp.so -shared -fPIC");
  void *handle = dlopen("./tmp.so", RTLD_NOW);
  if (handle != NULL) {
    Rcout << "Loading compiled code!" << std::endl;
    glob.forward_compiled =
        reinterpret_cast<void (*)(Scalar *)>(dlsym(handle, "forward"));
    glob.reverse_compiled = reinterpret_cast<void (*)(Scalar *, Scalar *)>(
        dlsym(handle, "reverse"));
  }
}
#include "compression.hpp"

std::ostream &operator<<(std::ostream &os, const period &x) {
  os << "begin: " << x.begin;
  os << " size: " << x.size;
  os << " rep: " << x.rep;
  return os;
}
std::vector<period> split_period(global *glob, period p,
                                 size_t max_period_size) {
  typedef std::ptrdiff_t ptrdiff_t;
  size_t offset = glob->subgraph_ptr[p.begin].first;
  for (size_t i = 0; i < p.size; i++) {
  }
  matrix_view<Index> x(&(glob->inputs[offset]), nrow, ncol);
  std::vector<bool> marks(ncol - 1, false);
  for (size_t i = 0; i < nrow; i++) {
    std::vector<period> pd;
    for (size_t j = 0; j < pd.size(); j++) {
      if (pd[j].begin > 0) {
        marks[pd[j].begin - 1] = true;
      }
      size_t end = pd[j].begin + pd[j].size * pd[j].rep;
      if (end < marks.size()) marks[end] = true;
    }
  }
  std::vector<period> ans;
  for (size_t j = 0; j < marks.size(); j++) {
  }
  return ans;
}
size_t compressed_input::input_size() const { return n; }

void compressed_input::update_increment_pattern() const {
  for (size_t i = 0; i < (size_t)np; i++)
    increment_pattern[which_periodic[i]] =
        period_data[period_offsets[i] + counter % period_sizes[i]];
}

void compressed_input::increment(Args<> &args) const {
  update_increment_pattern();
  for (size_t i = 0; i < n; i++) inputs[i] += increment_pattern[i];
}

void compressed_input::decrement(Args<> &args) const {
  args.ptr.first = input_size();
  for (size_t i = 0; i < n; i++) inputs[i] -= increment_pattern[i];
  update_increment_pattern();
}

void compressed_input::forward_init(Args<> &args) const {
  inputs.resize(input_size());
  for (size_t i = 0; i < inputs.size(); i++) inputs[i] = args.input(i);
  args.inputs = inputs.data();
}

void compressed_input::reverse_init(Args<> &args) {
  inputs.resize(input_size());
  for (size_t i = 0; i < inputs.size(); i++)
    inputs[i] = args.input(i) + input_diff[i];
  args.inputs = inputs.data();
  args.ptr.second += m * nrep;
  update_increment_pattern();
  args.ptr.first = input_size();
}
void compressed_input::dependencies_intervals(Args<> &args,
                                              std::vector<Index> &lower,
                                              std::vector<Index> &upper) const {
  for (size_t i = 0; i < nrep; i++) {
    for (size_t j = 0; j < inputs.size(); j++) {
      if (inputs[j] < lower[j]) lower[j] = inputs[j];
      if (inputs[j] > upper[j]) upper[j] = inputs[j];
    }
  }
}

bool compressed_input::test_period(std::vector<ptrdiff_t> &x, size_t p) {
  for (size_t j = 0; j < x.size(); j++) {
    if (x[j] != x[j % p]) return false;
  }
  return true;
}

size_t compressed_input::find_shortest(std::vector<ptrdiff_t> &x) {
  for (size_t p = 1; p < max_period_size; p++) {
    if (test_period(x, p)) return p;
  }
  return x.size();
}
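// Illustration (editor's note, based on the code above): test_period(x, p)
// checks whether the difference pattern x repeats with period p, and
// find_shortest() returns the smallest admissible p. For example, with
// x = {2, 3, 2, 3, 2, 3}, test_period(x, 1) fails (3 != 2) while
// test_period(x, 2) succeeds, so find_shortest(x) == 2 (assuming
// max_period_size > 2).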
compressed_input::compressed_input() {}

compressed_input::compressed_input(std::vector<Index> &x, size_t offset,
                                   size_t nrow, size_t m, size_t ncol,
                                   size_t max_period_size)
    : n(nrow), m(m), nrep(ncol), counter(0), max_period_size(max_period_size) {
  matrix_view<Index> xm(&x[offset], nrow, ncol);
  for (size_t i = 0; i < nrow; i++) {
    std::vector<ptrdiff_t> rd = xm.row_diff<ptrdiff_t>(i);
    size_t p = find_shortest(rd);
    increment_pattern.push_back(rd[0]);
    which_periodic.push_back(i);
    period_sizes.push_back(p);
    size_t pos = std::search(period_data.begin(), period_data.end(),
                             rd.begin(), rd.begin() + p) -
                 period_data.begin();
    if (pos < period_data.size()) {
      period_offsets.push_back(pos);
    } else {
      period_offsets.push_back(period_data.size());
      period_data.insert(period_data.end(), rd.begin(), rd.begin() + p);
    }
  }
  np = which_periodic.size();
  input_diff.resize(n, 0);
  for (size_t i = 0; i < nrep; i++) {
  }
}
StackOp::StackOp(global *glob, period p, IndexPair ptr,
                 size_t max_period_size) {
  opstack.resize(p.size);
  size_t n = 0, m = 0;
  for (size_t i = 0; i < p.size; i++) {
    n += opstack[i]->input_size();
    m += opstack[i]->output_size();
  }
  ci = compressed_input(glob->inputs, ptr.first, n, m, p.rep, max_period_size);
}

StackOp::StackOp(const StackOp &x) : opstack(x.opstack), ci(x.ci) {}
void StackOp::print(global::print_config cfg) {
  std::vector<const char *> tmp(opstack.size());
  for (size_t i = 0; i < opstack.size(); i++) tmp[i] = opstack[i]->op_name();
  Rcout << cfg.prefix << " opstack = " << tmp << "\n";
  Rcout << cfg.prefix << " nrep = " << ci.nrep << "\n";
  Rcout << cfg.prefix << " increment_pattern = " << ci.increment_pattern
        << "\n";
  if (ci.which_periodic.size() > 0) {
    Rcout << cfg.prefix << " which_periodic = " << ci.which_periodic << "\n";
    Rcout << cfg.prefix << " period_sizes = " << ci.period_sizes << "\n";
    Rcout << cfg.prefix << " period_offsets = " << ci.period_offsets << "\n";
    Rcout << cfg.prefix << " period_data = " << ci.period_data << "\n";
  }
}
Index StackOp::input_size() const { return ci.n; }

Index StackOp::output_size() const { return ci.m * ci.nrep; }
void StackOp::forward(ForwardArgs<Writer> &args) {
  size_t n = ci.n, m = ci.m, nrep = ci.nrep;
  std::vector<Index> inputs(n);
  for (size_t i = 0; i < (size_t)n; i++) inputs[i] = args.input(i);
  std::vector<Index> outputs(m);
  for (size_t i = 0; i < (size_t)m; i++) outputs[i] = args.output(i);
  size_t np = ci.which_periodic.size();
  size_t sp = ci.period_data.size();
  w << "for (int count = 0, ";
  w << "i[" << n << "]=" << inputs << ", "
    << "ip[" << n << "]=" << ci.increment_pattern << ", ";
  w << "wp[" << np << "]=" << ci.which_periodic << ", "
    << "ps[" << np << "]=" << ci.period_sizes << ", "
    << "po[" << np << "]=" << ci.period_offsets << ", "
    << "pd[" << sp << "]=" << ci.period_data << ", ";
  w << "o[" << m << "]=" << outputs << "; "
    << "count < " << nrep << "; count++) {\n";
  args_cpy.set_indirect();
  for (size_t k = 0; k < opstack.size(); k++) {
    opstack[k]->forward_incr(args_cpy);
  }
  for (size_t k = 0; k < np; k++)
    w << "ip[wp[" << k << "]] = pd[po[" << k << "] + count % ps[" << k
      << "]]; ";
  for (size_t k = 0; k < n; k++) w << "i[" << k << "] += ip[" << k << "]; ";
  for (size_t k = 0; k < m; k++) w << "o[" << k << "] += " << m << "; ";
}
void StackOp::reverse(ReverseArgs<Writer> &args) {
  size_t n = ci.n, m = ci.m, nrep = ci.nrep;
  std::vector<ptrdiff_t> inputs(input_size());
  for (size_t i = 0; i < inputs.size(); i++) {
    ptrdiff_t tmp;
    if (-ci.input_diff[i] < ci.input_diff[i]) {
      tmp = -((ptrdiff_t)-ci.input_diff[i]);
    } else {
      tmp = ci.input_diff[i];
    }
    inputs[i] = args.input(i) + tmp;
  }
  std::vector<Index> outputs(ci.m);
  for (size_t i = 0; i < (size_t)ci.m; i++)
    outputs[i] = args.output(i) + ci.m * ci.nrep;
  size_t np = ci.which_periodic.size();
  size_t sp = ci.period_data.size();
  w << "for (int count = " << nrep << ", ";
  w << "i[" << n << "]=" << inputs << ", "
    << "ip[" << n << "]=" << ci.increment_pattern << ", ";
  w << "wp[" << np << "]=" << ci.which_periodic << ", "
    << "ps[" << np << "]=" << ci.period_sizes << ", "
    << "po[" << np << "]=" << ci.period_offsets << ", "
    << "pd[" << sp << "]=" << ci.period_data << ", ";
  w << "o[" << m << "]=" << outputs << "; "
    << "count > 0 ; ) {\n";
  for (size_t k = 0; k < np; k++)
    w << "ip[wp[" << k << "]] = pd[po[" << k << "] + count % ps[" << k
      << "]]; ";
  for (size_t k = 0; k < n; k++) w << "i[" << k << "] -= ip[" << k << "]; ";
  for (size_t k = 0; k < m; k++) w << "o[" << k << "] -= " << m << "; ";
  args_cpy.set_indirect();
  args_cpy.ptr.first = ci.n;
  args_cpy.ptr.second = ci.m;
  for (size_t k = opstack.size(); k > 0;) {
    k--;
    opstack[k]->reverse_decr(args_cpy);
  }
}
void StackOp::dependencies(Args<> args, Dependencies &dep) const {
  std::vector<Index> lower;
  std::vector<Index> upper;
  ci.dependencies_intervals(args, lower, upper);
  for (size_t i = 0; i < lower.size(); i++) {
    dep.add_interval(lower[i], upper[i]);
  }
}

const char *StackOp::op_name() { return "StackOp"; }
  std::vector<Index> remap = radix::first_occurance<Index>(h);
  for (size_t i = 0; i < glob.opstack.size(); i++) {
    glob.opstack[i]->dependencies(args, dep);
    Index var = args.ptr.second;
    toposort_remap<Index> fb(remap, var);
  }
  std::vector<Index> ord = radix::order<Index>(remap);
  std::vector<Index> v2o = glob.var2op();
  glob.subgraph_seq = subset(v2o, ord);

  std::vector<Index> remap(glob.values.size(), Index(-1));
  for (size_t i = 0; i < glob.opstack.size(); i++) {
    glob.opstack[i]->dependencies(args, dep);
    Index var = args.ptr.second;
    temporaries_remap<Index> fb(remap, var);
    glob.opstack[i]->increment(args.ptr);
  }
  for (size_t i = remap.size(); i > 0;) {
    i--;
    if (remap[i] == Index(-1)) continue;
    remap[i] = remap[remap[i]];
  }
  std::vector<Index> ord = radix::order<Index>(remap);
  std::vector<Index> v2o = glob.var2op();
  glob.subgraph_seq = subset(v2o, ord);

  std::vector<bool> visited(glob.opstack.size(), false);
  std::vector<Index> v2o = glob.var2op();
  std::vector<Index> stack;
  std::vector<Index> result;
  for (size_t k = 0; k < glob.dep_index.size(); k++) {
    Index dep_var = glob.dep_index[k];
    Index i = v2o[dep_var];
    while (stack.size() > 0) {
      Index i = stack.back();
      args.ptr = glob.subgraph_ptr[i];
      glob.opstack[i]->dependencies(args, dep);
      dfs_add_to_stack<Index> add_to_stack(stack, visited, v2o);
      size_t before = stack.size();
      dep.apply(add_to_stack);
      size_t after = stack.size();
      if (before == after) {
      }
    }
  }
  glob.subgraph_seq = result;
void compress(global &glob, size_t max_period_size) {
  size_t min_period_rep = TMBAD_MIN_PERIOD_REP;
  std::vector<period> periods = p.find_all();
  std::vector<period> periods_expand;
  for (size_t i = 0; i < periods.size(); i++) {
    std::vector<period> tmp = split_period(&glob, periods[i], max_period_size);
    if (tmp.size() > 10) {
      tmp.resize(0);
      tmp.push_back(periods[i]);
    }
    for (size_t j = 0; j < tmp.size(); j++) {
      if (tmp[j].rep > 1) periods_expand.push_back(tmp[j]);
    }
  }
  std::swap(periods, periods_expand);
  for (size_t i = 0; i < periods.size(); i++) {
    period p = periods[i];
    TMBAD_ASSERT(p.rep >= 1);
    while (k < p.begin) {
      glob.opstack[k]->increment(ptr);
      k++;
    }
    OperatorPure *pOp =
        get_glob()->getOperator<StackOp>(&glob, p, ptr, max_period_size);
    size_t ninp = 0;
    for (size_t j = 0; j < p.size * p.rep; j++) {
      ninp += glob.opstack[p.begin + j]->input_size();
      glob.opstack[p.begin + j]->deallocate();
      glob.opstack[p.begin + j] = null_op;
    }
    glob.opstack[p.begin] = pOp;
    glob.opstack[p.begin + 1] =
        get_glob()->getOperator<global::NullOp2>(ninp - pOp->input_size(), 0);
  }
  std::vector<bool> marks(glob.values.size(), true);
}
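// Editor's note (based on the code above, not a comment from the original
// source): compress() locates repeated operator periods on the tape,
// replaces each repeated block [p.begin, p.begin + p.size * p.rep) by a
// single StackOp that replays the block p.rep times using the compressed
// input increments, and overwrites the freed slots with null operators so
// that operator indices stay aligned.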
#include "global.hpp"

global *global_ptr_data[TMBAD_MAX_NUM_THREADS] = {NULL};
global **global_ptr = global_ptr_data;
std::ostream *Writer::cout = 0;
bool global::fuse = 0;
Dependencies::Dependencies() {}

void Dependencies::clear() {
  this->resize(0);
  I.resize(0);
}

void Dependencies::add_interval(Index a, Index b) {
  I.push_back(std::pair<Index, Index>(a, b));
}

void Dependencies::add_segment(Index start, Index size) {
  if (size > 0) add_interval(start, start + size - 1);
}

void Dependencies::monotone_transform_inplace(const std::vector<Index> &x) {
  for (size_t i = 0; i < this->size(); i++) (*this)[i] = x[(*this)[i]];
  for (size_t i = 0; i < I.size(); i++) {
    I[i].first = x[I[i].first];
    I[i].second = x[I[i].second];
  }
}

bool Dependencies::any(const std::vector<bool> &x) const {
  for (size_t i = 0; i < this->size(); i++)
    if (x[(*this)[i]]) return true;
  for (size_t i = 0; i < I.size(); i++) {
    for (Index j = I[i].first; j <= I[i].second; j++) {
      if (x[j]) return true;
    }
  }
  return false;
}
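// Usage sketch (editor's illustration, not part of the original source):
// a Dependencies object stores single variable indices (the vector part)
// plus closed intervals in I, and any() tests them against a mark array.
//
//   Dependencies dep;
//   dep.add_segment(4, 3);            // records the interval [4, 6]
//   std::vector<bool> marked(10, false);
//   marked[5] = true;
//   bool hit = dep.any(marked);       // true: index 5 lies inside [4, 6]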
std::string tostr(const Index &x) {
  std::ostringstream strs;
  strs << x;
  return strs.str();
}

std::string tostr(const Scalar &x) {
  std::ostringstream strs;
  strs << x;
  return strs.str();
}

Writer::Writer(std::string str) : std::string(str) {}

Writer::Writer(Scalar x) : std::string(tostr(x)) {}
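// Usage sketch (editor's illustration, not part of the original source):
// together with the operators defined just below, Writer is a std::string
// that, instead of computing, prints C source text to the stream installed
// in Writer::cout.
//
//   std::ostringstream buf;
//   Writer::cout = &buf;
//   Writer a("v[0]"), b("v[1]"), c("v[2]");
//   c = a + b;                        // buf now holds "v[2] = (v[0] + v[1]);"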
std::string Writer::p(std::string x) { return "(" + x + ")"; }

Writer Writer::operator+(const Writer &other) {
  return p(*this + " + " + other);
}

Writer Writer::operator-(const Writer &other) {
  return p(*this + " - " + other);
}

Writer Writer::operator-() { return " - " + *this; }

Writer Writer::operator/(const Writer &other) { return *this + " / " + other; }

Writer Writer::operator*(const Scalar &other) {
  return *this + "*" + tostr(other);
}

Writer Writer::operator+(const Scalar &other) {
  return p(*this + "+" + tostr(other));
}

void Writer::operator=(const Writer &other) {
  *cout << *this + " = " + other << ";";
}

void Writer::operator+=(const Writer &other) {
  *cout << *this + " += " + other << ";";
}

void Writer::operator-=(const Writer &other) {
  *cout << *this + " -= " + other << ";";
}

void Writer::operator*=(const Writer &other) {
  *cout << *this + " *= " + other << ";";
}

void Writer::operator/=(const Writer &other) {
  *cout << *this + " /= " + other << ";";
}
Position::Position(Index node, Index first, Index second)
    : node(node), ptr(first, second) {}

Position::Position() : node(0), ptr(0, 0) {}

bool Position::operator<(const Position &other) const {
  return this->node < other.node;
}

size_t graph::num_neighbors(Index node) { return p[node + 1] - p[node]; }

Index *graph::neighbors(Index node) { return &(j[p[node]]); }

bool graph::empty() { return p.size() == 0; }

size_t graph::num_nodes() { return (empty() ? 0 : p.size() - 1); }
void graph::print() {
  for (size_t node = 0; node < num_nodes(); node++) {
    Rcout << node << ": ";
    for (size_t i = 0; i < num_neighbors(node); i++) {
      Rcout << " " << neighbors(node)[i];
    }
    Rcout << "\n";
  }
}

std::vector<Index> graph::rowcounts() {
  std::vector<Index> ans(num_nodes());
  for (size_t i = 0; i < ans.size(); i++) ans[i] = num_neighbors(i);
  return ans;
}

std::vector<Index> graph::colcounts() {
  std::vector<Index> ans(num_nodes());
  for (size_t i = 0; i < j.size(); i++) ans[j[i]]++;
  return ans;
}
void graph::bfs(const std::vector<Index> &start, std::vector<bool> &visited,
                std::vector<Index> &result) {
  for (size_t i = 0; i < start.size(); i++) {
    Index node = start[i];
    for (size_t j_ = 0; j_ < num_neighbors(node); j_++) {
      Index k = neighbors(node)[j_];
      if (!visited[k]) {
        visited[k] = true;
        result.push_back(k);
      }
    }
  }
}

void graph::search(std::vector<Index> &start, bool sort_input,
                   bool sort_output) {
  if (mark.size() == 0) mark.resize(num_nodes(), false);
  search(start, mark, sort_input, sort_output);
  for (size_t i = 0; i < start.size(); i++) mark[start[i]] = false;
}

void graph::search(std::vector<Index> &start, std::vector<bool> &visited,
                   bool sort_input, bool sort_output) {
  for (size_t i = 0; i < start.size(); i++) visited[start[i]] = true;
  bfs(start, visited, start);
}

std::vector<Index> graph::boundary(const std::vector<Index> &subgraph) {
  if (mark.size() == 0) mark.resize(num_nodes(), false);
  std::vector<Index> boundary;
  for (size_t i = 0; i < subgraph.size(); i++) mark[subgraph[i]] = true;
  bfs(subgraph, mark, boundary);
  for (size_t i = 0; i < subgraph.size(); i++) mark[subgraph[i]] = false;
  for (size_t i = 0; i < boundary.size(); i++) mark[boundary[i]] = false;
  return boundary;
}
graph::graph(size_t num_nodes, const std::vector<IndexPair> &edges) {
  std::vector<IndexPair>::const_iterator it;
  std::vector<Index> row_counts(num_nodes, 0);
  for (it = edges.begin(); it != edges.end(); it++) {
    row_counts[it->first]++;
  }
  p.resize(num_nodes + 1);
  for (size_t i = 0; i < num_nodes; i++) {
    p[i + 1] = p[i] + row_counts[i];
  }
  std::vector<Index> k(p);
  j.resize(edges.size());
  for (it = edges.begin(); it != edges.end(); it++) {
    j[k[it->first]++] = it->second;
  }
}
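// Illustration (editor's sketch): the constructor stores the graph in CSR
// form, p holding row offsets and j the column indices. For num_nodes = 3
// and edges {(0,1), (0,2), (2,1)} this yields p = {0, 2, 2, 3} and
// j = {1, 2, 1}, so num_neighbors(0) == 2 and neighbors(2)[0] == 1.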
op_info::op_info() : code(0) {
  static_assert(sizeof(IntRep) * 8 >= op_flag_count,
                "'IntRep' not wide enough!");
}

op_info::op_info(op_flag f) : code(1 << f) {}

bool op_info::test(op_flag f) const { return code & (1 << f); }
global::operation_stack::operation_stack() {}

global::operation_stack::operation_stack(const operation_stack &other) {
  (*this).copy_from(other);
}

global::operation_stack &global::operation_stack::operator=(
    const operation_stack &other) {
  if (this != &other) {
    (*this).clear();
    (*this).copy_from(other);
  }
  return *this;
}

global::operation_stack::~operation_stack() { (*this).clear(); }

void global::operation_stack::clear() {
  if (any.test(op_info::dynamic)) {
    for (size_t i = 0; i < (*this).size(); i++) (*this)[i]->deallocate();
  }
  this->resize(0);
}

void global::operation_stack::copy_from(const operation_stack &other) {
  if (other.any.test(op_info::dynamic)) {
    for (size_t i = 0; i < other.size(); i++) Base::push_back(other[i]->copy());
  } else {
    Base::operator=(other);
  }
  this->any = other.any;
}
global::global()
    : forward_compiled(NULL),
      reverse_compiled(NULL) {}

void global::clear() {
  subgraph_ptr.resize(0);
  subgraph_seq.resize(0);
}

void global::shrink_to_fit(double tol) {
  std::vector<Scalar>().swap(derivs);
  std::vector<IndexPair>().swap(subgraph_ptr);
  if (values.size() < tol * values.capacity())
    std::vector<Scalar>(values).swap(values);
  if (inputs.size() < tol * inputs.capacity())
    std::vector<Index>(inputs).swap(inputs);
  if (opstack.size() < tol * opstack.capacity())
    std::vector<OperatorPure *>(opstack).swap(opstack);
}

void global::clear_deriv(Position start) {
  derivs.resize(values.size());
  std::fill(derivs.begin() + start.ptr.second, derivs.end(), 0);
}
Scalar &global::value_inv(Index i) { return values[inv_index[i]]; }

Scalar &global::deriv_inv(Index i) { return derivs[inv_index[i]]; }

Scalar &global::value_dep(Index i) { return values[dep_index[i]]; }

Scalar &global::deriv_dep(Index i) { return derivs[dep_index[i]]; }

Position global::begin() { return Position(0, 0, 0); }

Position global::end() {
  return Position(opstack.size(), inputs.size(), values.size());
}

CONSTEXPR bool global::no_filter::operator[](size_t i) const { return true; }
void global::forward(Position start) {
  if (forward_compiled != NULL) {
    forward_compiled(values.data());
    return;
  }
  args.ptr = start.ptr;
  forward_loop(args, start.node);
}

void global::reverse(Position start) {
  if (reverse_compiled != NULL) {
    reverse_compiled(values.data(), derivs.data());
    return;
  }
  reverse_loop(args, start.node);
}

void global::forward_sub() {
  forward_loop_subgraph(args);
}

void global::reverse_sub() {
  reverse_loop_subgraph(args);
}

void global::forward(std::vector<bool> &marks) {
}

void global::reverse(std::vector<bool> &marks) {
}

void global::forward_sub(std::vector<bool> &marks,
                         const std::vector<bool> &node_filter) {
  if (node_filter.size() == 0)
    forward_loop_subgraph(args);
  else
    forward_loop(args, 0, node_filter);
}

void global::reverse_sub(std::vector<bool> &marks,
                         const std::vector<bool> &node_filter) {
  if (node_filter.size() == 0)
    reverse_loop_subgraph(args);
  else
    reverse_loop(args, 0, node_filter);
}
void global::forward_dense(std::vector<bool> &marks) {
  for (size_t i = 0; i < opstack.size(); i++) {
    opstack[i]->forward_incr_mark_dense(args);
  }
}

  for (size_t i = 0; i < opstack.size(); i++) {
    if (opstack[i]->info().test(op_info::updating)) {
      opstack[i]->dependencies(args, dep);
      for (size_t i = 0; i < dep.I.size(); i++) {
        Index a = dep.I[i].first;
        Index b = dep.I[i].second;
        marked_intervals.insert(a, b);
      }
    }
    opstack[i]->increment(args.ptr);
  }
  return marked_intervals;

  subgraph_cache_ptr();
  for (size_t j = 0; j < subgraph_seq.size(); j++) {
    Index i = subgraph_seq[j];
    args.ptr = subgraph_ptr[i];
    if (opstack[i]->info().test(op_info::updating)) {
      opstack[i]->dependencies(args, dep);
      for (size_t i = 0; i < dep.I.size(); i++) {
        Index a = dep.I[i].first;
        Index b = dep.I[i].second;
        marked_intervals.insert(a, b);
      }
    }
  }
  return marked_intervals;
Replay &global::replay::value_inv(Index i) { return values[orig.inv_index[i]]; }

Replay &global::replay::deriv_inv(Index i) { return derivs[orig.inv_index[i]]; }

Replay &global::replay::value_dep(Index i) { return values[orig.dep_index[i]]; }

Replay &global::replay::deriv_dep(Index i) { return derivs[orig.dep_index[i]]; }

global::replay::replay(const global &orig, global &target)
    : orig(orig), target(target) {
  TMBAD_ASSERT(&orig != &target);
}

void global::replay::start() {
  parent_glob = get_glob();
  if (&target != parent_glob) target.ad_start();
  values = std::vector<Replay>(orig.values.begin(), orig.values.end());
}

void global::replay::stop() {
  if (&target != parent_glob) target.ad_stop();
  TMBAD_ASSERT(parent_glob == get_glob());
}

    void operator()(Index a, Index b) {
      Index n = b - a + 1;
    }
  } F = {derivs.data()};

void global::replay::clear_deriv() {
  derivs.resize(values.size());
  std::fill(derivs.begin(), derivs.end(), Replay(0));
  add_updatable_derivs(I);
}
void global::replay::forward(bool inv_tags, bool dep_tags, Position start,
                             const std::vector<bool> &node_filter) {
  TMBAD_ASSERT(&target == get_glob());
  for (size_t i = 0; i < orig.inv_index.size(); i++)
    value_inv(i).Independent();
  if (node_filter.size() > 0) {
    TMBAD_ASSERT(node_filter.size() == orig.opstack.size());
  }
  for (size_t i = 0; i < orig.dep_index.size(); i++) value_dep(i).Dependent();
}

void global::replay::reverse(bool dep_tags, bool inv_tags, Position start,
                             const std::vector<bool> &node_filter) {
  TMBAD_ASSERT(&target == get_glob());
  for (size_t i = 0; i < orig.dep_index.size(); i++)
    deriv_dep(i).Independent();
  if (node_filter.size() > 0) {
    TMBAD_ASSERT(node_filter.size() == orig.opstack.size());
  }
  std::fill(derivs.begin(), derivs.begin() + start.ptr.second, Replay(0));
  for (size_t i = 0; i < orig.inv_index.size(); i++) deriv_inv(i).Dependent();
}

void global::replay::forward_sub() {
}

void global::replay::reverse_sub() {
}

void global::replay::clear_deriv_sub() {
  add_updatable_derivs(I);
}

void global::forward_replay(bool inv_tags, bool dep_tags) {
  global new_glob;
  global::replay replay(*this, new_glob);
  replay.forward(inv_tags, dep_tags);
}
void global::subgraph_cache_ptr() const {
  if (subgraph_ptr.size() == opstack.size()) return;
  TMBAD_ASSERT(subgraph_ptr.size() < opstack.size());
  if (subgraph_ptr.size() == 0) subgraph_ptr.push_back(IndexPair(0, 0));
  for (size_t i = subgraph_ptr.size(); i < opstack.size(); i++) {
    IndexPair ptr = subgraph_ptr[i - 1];
    opstack[i - 1]->increment(ptr);
    subgraph_ptr.push_back(ptr);
  }
}
void global::set_subgraph(const std::vector<bool> &marks, bool append) {
  std::vector<Index> v2o = var2op();
  if (!append) subgraph_seq.resize(0);
  Index previous = (Index)-1;
  for (size_t i = 0; i < marks.size(); i++) {
    if (marks[i] && (v2o[i] != previous)) {
      subgraph_seq.push_back(v2o[i]);
      previous = v2o[i];
    }
  }
}

void global::mark_subgraph(std::vector<bool> &marks) {
  TMBAD_ASSERT(marks.size() == values.size());
  clear_array_subgraph(marks, true);
}

void global::unmark_subgraph(std::vector<bool> &marks) {
  TMBAD_ASSERT(marks.size() == values.size());
  clear_array_subgraph(marks, false);
}

void global::subgraph_trivial() {
  subgraph_cache_ptr();
  subgraph_seq.resize(0);
  for (size_t i = 0; i < opstack.size(); i++) subgraph_seq.push_back(i);
}

void global::clear_deriv_sub() { clear_array_subgraph(derivs); }
global global::extract_sub(std::vector<Index> &var_remap, global new_glob) {
  subgraph_cache_ptr();
  TMBAD_ASSERT(var_remap.size() == 0 || var_remap.size() == values.size());
  var_remap.resize(values.size(), 0);
  std::vector<bool> independent_variable = inv_marks();
  std::vector<bool> dependent_variable = dep_marks();
  for (size_t j = 0; j < subgraph_seq.size(); j++) {
    Index i = subgraph_seq[j];
    args.ptr = subgraph_ptr[i];
    size_t nout = opstack[i]->output_size();
    for (size_t k = 0; k < nout; k++) {
      Index new_index = new_glob.values.size();
      Index old_index = args.output(k);
      var_remap[old_index] = new_index;
      new_glob.values.push_back(args.y(k));
      if (independent_variable[old_index]) {
        independent_variable[old_index] = false;
      }
      if (dependent_variable[old_index]) {
        dependent_variable[old_index] = false;
      }
    }
    size_t nin = opstack[i]->input_size();
    for (size_t k = 0; k < nin; k++) {
      new_glob.inputs.push_back(var_remap[args.input(k)]);
    }
  }
  independent_variable.flip();
  dependent_variable.flip();
  for (size_t i = 0; i < inv_index.size(); i++) {
    Index old_var = inv_index[i];
    if (independent_variable[old_var])
      new_glob.inv_index.push_back(var_remap[old_var]);
  }
  for (size_t i = 0; i < dep_index.size(); i++) {
    Index old_var = dep_index[i];
    if (dependent_variable[old_var])
      new_glob.dep_index.push_back(var_remap[old_var]);
  }
  return new_glob;
}
void global::extract_sub_inplace(std::vector<bool> marks) {
  TMBAD_ASSERT(marks.size() == values.size());
  std::vector<Index> var_remap(values.size(), 0);
  std::vector<bool> independent_variable = inv_marks();
  std::vector<bool> dependent_variable = dep_marks();
  size_t s = 0, s_input = 0;
  std::vector<bool> opstack_deallocate(opstack.size(), false);
  for (size_t i = 0; i < opstack.size(); i++) {
    op_info info = opstack[i]->info();
    size_t nout = opstack[i]->output_size();
    bool any_marked_output = info.test(op_info::elimination_protected);
    for (size_t j = 0; j < nout; j++) {
      any_marked_output |= args.y(j);
    }
    if (info.test(op_info::updating) && nout == 0) {
      opstack[i]->dependencies_updating(args, dep);
      any_marked_output |= dep.any(args.values);
    }
    if (any_marked_output) {
      for (size_t k = 0; k < nout; k++) {
        Index new_index = s;
        Index old_index = args.output(k);
        var_remap[old_index] = new_index;
        values[new_index] = values[old_index];
        if (independent_variable[old_index]) {
          independent_variable[old_index] = false;
        }
        if (dependent_variable[old_index]) {
          dependent_variable[old_index] = false;
        }
        s++;
      }
      size_t nin = opstack[i]->input_size();
      for (size_t k = 0; k < nin; k++) {
        inputs[s_input] = var_remap[args.input(k)];
        s_input++;
      }
    }
    opstack[i]->increment(args.ptr);
    if (!any_marked_output) {
      opstack_deallocate[i] = true;
    }
  }
  independent_variable.flip();
  dependent_variable.flip();
  std::vector<Index> new_inv_index;
  for (size_t i = 0; i < inv_index.size(); i++) {
    Index old_var = inv_index[i];
    if (independent_variable[old_var])
      new_inv_index.push_back(var_remap[old_var]);
  }
  inv_index = new_inv_index;
  std::vector<Index> new_dep_index;
  for (size_t i = 0; i < dep_index.size(); i++) {
    Index old_var = dep_index[i];
    if (dependent_variable[old_var])
      new_dep_index.push_back(var_remap[old_var]);
  }
  dep_index = new_dep_index;
  inputs.resize(s_input);
  size_t k = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    if (opstack_deallocate[i]) {
      opstack[i]->deallocate();
    } else {
      opstack[k] = opstack[i];
      k++;
    }
  }
  if (opstack.any.test(op_info::dynamic)) this->forward();
}

global global::extract_sub() {
  std::vector<Index> var_remap;
  return extract_sub(var_remap);
}
std::vector<Index> global::var2op() {
  std::vector<Index> var2op(values.size());
  size_t j = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    opstack[i]->increment(args.ptr);
    for (; j < (size_t)args.ptr.second; j++) {
      var2op[j] = i;
    }
  }
  return var2op;
}

std::vector<bool> global::var2op(const std::vector<bool> &values) {
  std::vector<bool> ans(opstack.size(), false);
  size_t j = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    opstack[i]->increment(args.ptr);
    for (; j < (size_t)args.ptr.second; j++) {
      ans[i] = ans[i] || values[j];
    }
  }
  return ans;
}

std::vector<Index> global::op2var(const std::vector<Index> &seq) {
  std::vector<bool> seq_mark = mark_space(opstack.size(), seq);
  std::vector<Index> ans;
  size_t j = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    opstack[i]->increment(args.ptr);
    for (; j < (size_t)args.ptr.second; j++) {
      if (seq_mark[i]) ans.push_back(j);
    }
  }
  return ans;
}

std::vector<bool> global::op2var(const std::vector<bool> &seq_mark) {
  std::vector<bool> ans(values.size());
  size_t j = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    opstack[i]->increment(args.ptr);
    for (; j < (size_t)args.ptr.second; j++) {
      if (seq_mark[i]) ans[j] = true;
    }
  }
  return ans;
}

std::vector<Index> global::op2idx(const std::vector<Index> &var_subset,
                                  Index NA) {
  std::vector<Index> v2o = var2op();
  std::vector<Index> op2idx(opstack.size(), NA);
  for (size_t i = var_subset.size(); i > 0;) {
    i--;
    op2idx[v2o[var_subset[i]]] = i;
  }
  return op2idx;
}
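// Illustration (editor's sketch): var2op() maps each value index to the
// operator that produced it. If operator 0 writes two outputs and operator
// 1 writes one output, var2op() returns {0, 0, 1}, and op2var({1}) returns
// the variable indices owned by operator 1, i.e. {2}.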
std::vector<bool> global::mark_space(size_t n, const std::vector<Index> ind) {
  std::vector<bool> mark(n, false);
  for (size_t i = 0; i < ind.size(); i++) {
    mark[ind[i]] = true;
  }
  return mark;
}

std::vector<bool> global::inv_marks() {
  return mark_space(values.size(), inv_index);
}

std::vector<bool> global::dep_marks() {
  return mark_space(values.size(), dep_index);
}

std::vector<bool> global::subgraph_marks() {
  return mark_space(opstack.size(), subgraph_seq);
}
global::append_edges::append_edges(size_t &i, size_t num_nodes,
                                   const std::vector<bool> &keep_var,
                                   std::vector<Index> &var2op,
                                   std::vector<IndexPair> &edges)
    : i(i),
      keep_var(keep_var),
      var2op(var2op),
      edges(edges),
      op_marks(num_nodes, false) {}

void global::append_edges::operator()(Index dep_j) {
  if (keep_var[dep_j]) {
    size_t k = var2op[dep_j];
    if (i != k && !op_marks[k]) {
      op_marks[k] = true;
      IndexPair edge(k, i);
      edges.push_back(edge);
    }
  }
}

void global::append_edges::start_iteration() { pos = edges.size(); }

void global::append_edges::end_iteration() {
  size_t n = edges.size() - pos;
  for (size_t j = 0; j < n; j++) op_marks[edges[pos + j].first] = false;
}
graph global::build_graph(bool transpose, const std::vector<bool> &keep_var) {
  TMBAD_ASSERT(keep_var.size() == values.size());
  std::vector<Index> var2op = this->var2op();
  bool any_updating = false;
  std::vector<IndexPair> edges;
  size_t i = 0;
  append_edges F(i, opstack.size(), keep_var, var2op, edges);
  for (; i < opstack.size(); i++) {
    any_updating |= opstack[i]->info().test(op_info::updating);
    opstack[i]->dependencies(args, dep);
    F.start_iteration();
    dep.apply(F);
    F.end_iteration();
    opstack[i]->increment(args.ptr);
  }
  size_t begin = edges.size();
  for (; i < opstack.size(); i++) {
    opstack[i]->dependencies_updating(args, dep);
    F.start_iteration();
    dep.apply(F);
    F.end_iteration();
    opstack[i]->increment(args.ptr);
  }
  for (size_t j = begin; j < edges.size(); j++)
    std::swap(edges[j].first, edges[j].second);
  for (size_t j = 0; j < edges.size(); j++)
    std::swap(edges[j].first, edges[j].second);
  graph G(opstack.size(), edges);
  for (size_t i = 0; i < inv_index.size(); i++)
    G.inv2op.push_back(var2op[inv_index[i]]);
  for (size_t i = 0; i < dep_index.size(); i++)
    G.dep2op.push_back(var2op[dep_index[i]]);
  return G;
}
graph global::forward_graph(std::vector<bool> keep_var) {
  if (keep_var.size() == 0) {
    keep_var.resize(values.size(), true);
  }
  TMBAD_ASSERT(values.size() == keep_var.size());
  return build_graph(false, keep_var);
}

graph global::reverse_graph(std::vector<bool> keep_var) {
  if (keep_var.size() == 0) {
    keep_var.resize(values.size(), true);
  }
  TMBAD_ASSERT(values.size() == keep_var.size());
  return build_graph(true, keep_var);
}
bool global::identical(const global &other) const {
  if (inv_index != other.inv_index) return false;
  if (dep_index != other.dep_index) return false;
  if (opstack.size() != other.opstack.size()) return false;
  for (size_t i = 0; i < opstack.size(); i++) {
    if (opstack[i]->identifier() != other.opstack[i]->identifier())
      return false;
  }
  if (inputs != other.inputs) return false;
  if (values.size() != other.values.size()) return false;
  IndexPair ptr(0, 0);
  for (size_t i = 0; i < opstack.size(); i++) {
    if (opstack[i] == constant) {
      if (values[ptr.second] != other.values[ptr.second]) return false;
    }
    opstack[i]->increment(ptr);
  }
  return true;
}
hash_t global::hash() const {
  hash_t h = 37;
  hash(h, inv_index.size());
  for (size_t i = 0; i < inv_index.size(); i++) hash(h, inv_index[i]);
  hash(h, dep_index.size());
  for (size_t i = 0; i < dep_index.size(); i++) hash(h, dep_index[i]);
  hash(h, opstack.size());
  for (size_t i = 0; i < opstack.size(); i++) hash(h, opstack[i]);
  hash(h, inputs.size());
  for (size_t i = 0; i < inputs.size(); i++) hash(h, inputs[i]);
  hash(h, values.size());
  IndexPair ptr(0, 0);
  for (size_t i = 0; i < opstack.size(); i++) {
    if (opstack[i] == constant) {
      hash(h, values[ptr.second]);
    }
    opstack[i]->increment(ptr);
  }
  return h;
}
std::vector<hash_t> global::hash_sweep(hash_config cfg) const {
  std::vector<Index> opstack_id;
  std::vector<size_t> tmp(opstack.size());
  for (size_t i = 0; i < tmp.size(); i++)
    tmp[i] = (size_t)opstack[i]->identifier();
  opstack_id = radix::first_occurance<Index>(tmp);
  hash_t spread = (hash_t(1) << (sizeof(hash_t) * 4)) - 1;
  for (size_t i = 0; i < opstack_id.size(); i++)
    opstack_id[i] = (opstack_id[i] + 1) * spread;
  std::vector<hash_t> hash_vec(values.size(), 37);
  bool have_inv_seed = (cfg.inv_seed.size() > 0);
  if (have_inv_seed) {
    TMBAD_ASSERT(cfg.inv_seed.size() == inv_index.size());
  }
  for (size_t i = 0; i < inv_index.size(); i++) {
    hash_vec[inv_index[i]] += (have_inv_seed ? cfg.inv_seed[i] + 1 : (i + 1));
  }
  IndexPair &ptr = args.ptr;
  for (size_t i = 0; i < opstack.size(); i++) {
    if (opstack[i] == inv) {
      opstack[i]->increment(ptr);
      continue;
    }
    opstack[i]->dependencies(args, dep);
    hash_t h = 37;
    for (size_t j = 0; j < dep.size(); j++) {
      if (j == 0)
        h = hash_vec[dep[0]];
      else
        hash(h, hash_vec[dep[j]]);
    }
    hash(h, opstack[i]->identifier());
    hash(h, opstack_id[i]);
    hash(h, values[ptr.second]);
    hash(h, values[ptr.second] > 0);
    size_t noutput = opstack[i]->output_size();
    for (size_t j = 0; j < noutput; j++) {
      hash_vec[ptr.second + j] = h;
    }
    opstack[i]->increment(ptr);
  }
  if (!cfg.reduce) return hash_vec;
  std::vector<hash_t> ans(dep_index.size());
  for (size_t j = 0; j < dep_index.size(); j++) {
    ans[j] = hash_vec[dep_index[j]];
  }
  return ans;
}

std::vector<hash_t> global::hash_sweep(bool weak) const {
  hash_config cfg;
  return hash_sweep(cfg);
}
void global::eliminate() {
  this->shrink_to_fit();
  std::vector<bool> marks;
  marks.resize(values.size(), false);
  for (size_t i = 0; i < inv_index.size(); i++) marks[inv_index[i]] = true;
  for (size_t i = 0; i < dep_index.size(); i++) marks[dep_index[i]] = true;
  set_subgraph(marks);
  *this = extract_sub();
  this->extract_sub_inplace(marks);
  this->shrink_to_fit();
}
global::print_config::print_config() : prefix(""), mark("*"), depth(0) {}

void global::print(print_config cfg) {
  IndexPair ptr(0, 0);
  std::vector<bool> sgm = subgraph_marks();
  bool have_subgraph = (subgraph_seq.size() > 0);
  print_config cfg2 = cfg;
  cfg2.prefix = cfg.prefix + "##";
  Rcout << cfg.prefix;
  Rcout << setw(7) << "OpName:" << setw(7 + have_subgraph)
        << "Node:" << setw(13) << "Value:" << setw(13) << "Deriv:" << setw(13);
  for (size_t i = 0; i < opstack.size(); i++) {
    Rcout << cfg.prefix;
    Rcout << setw(7) << opstack[i]->op_name();
    if (have_subgraph) {
    }
    Rcout << setw(7) << i;
    int numvar = opstack[i]->output_size();
    for (int j = 0; j < numvar + (numvar == 0); j++) {
      if (j > 0) Rcout << cfg.prefix;
      Rcout << setw((7 + 7) * (j > 0) + 13);
      if (derivs.size() == values.size()) {
      }
    }
    IndexPair ptr_old = ptr;
    opstack[i]->increment(ptr);
    int ninput = ptr.first - ptr_old.first;
    for (int k = 0; k < ninput; k++) {
      if (k == 0) Rcout << " ";
      Rcout << " " << inputs[ptr_old.first + k];
    }
    if (cfg.depth > 0) opstack[i]->print(cfg2);
  }
}
global::DynamicInputOutputOperator::DynamicInputOutputOperator(Index ninput,
                                                               Index noutput)
    : ninput_(ninput), noutput_(noutput) {}

Index global::DynamicInputOutputOperator::input_size() const {
  return this->ninput_;
}

Index global::DynamicInputOutputOperator::output_size() const {
  return this->noutput_;
}

const char *global::InvOp::op_name() { return "InvOp"; }

const char *global::DepOp::op_name() { return "DepOp"; }
  args.y(0).addToTape();

const char *global::ConstOp::op_name() { return "ConstOp"; }

  if (args.const_literals) {
    args.y(0) = args.y_const(0);
  }

global::DataOp::DataOp(Index n) { Base::noutput = n; }

const char *global::DataOp::op_name() { return "DataOp"; }

global::ZeroOp::ZeroOp(Index n) { Base::noutput = n; }

const char *global::ZeroOp::op_name() { return "ZeroOp"; }

void global::ZeroOp::operator()(Replay *x, Index n) {
  for (size_t i = 0; i < n; i++) x[i] = y[i];
}

global::NullOp::NullOp() {}

const char *global::NullOp::op_name() { return "NullOp"; }

global::NullOp2::NullOp2(Index ninput, Index noutput)
    : global::DynamicInputOutputOperator(ninput, noutput) {}

const char *global::NullOp2::op_name() { return "NullOp2"; }
global::RefOp::RefOp(global *glob, Index i) : glob(glob), i(i) {}

  args.dx(0) += args.dy(0);

const char *global::RefOp::op_name() { return "RefOp"; }

void global::set_fuse(bool flag) { fuse = flag; }

  while (this->opstack.size() > 0) {
    OperatorPure *OpTry = this->Fuse(this->opstack.back(), pOp);
    if (OpTry == NULL) break;
    pOp = OpTry;
    this->opstack.pop_back();
  }
  this->opstack.push_back(pOp);
bool global::ad_plain::initialized() const { return index != NA; }

bool global::ad_plain::on_some_tape() const { return initialized(); }

void global::ad_plain::addToTape() const { TMBAD_ASSERT(initialized()); }

global *global::ad_plain::glob() const {
  return (on_some_tape() ? get_glob() : NULL);
}

void global::ad_plain::override_by(const ad_plain &x) const {}

global::ad_plain::ad_plain() : index(NA) {}

global::ad_plain::ad_plain(Scalar x) {
}

global::ad_plain::ad_plain(ad_aug x) {
}

Replay global::ad_plain::CopyOp::eval(Replay x0) { return x0.copy(); }

const char *global::ad_plain::CopyOp::op_name() { return "CopyOp"; }

ad_plain global::ad_plain::copy() const {
}

Replay global::ad_plain::ValOp::eval(Replay x0) { return x0.copy0(); }

void global::ad_plain::ValOp::dependencies(Args<> &args,
                                           Dependencies &dep) const {}

const char *global::ad_plain::ValOp::op_name() { return "ValOp"; }

ad_plain global::ad_plain::copy0() const {
}
ad_plain global::ad_plain::operator+(const ad_plain &other) const {
}

ad_plain global::ad_plain::operator-(const ad_plain &other) const {
}

ad_plain global::ad_plain::operator/(const ad_plain &other) const {
}

const char *global::ad_plain::NegOp::op_name() { return "NegOp"; }

ad_plain global::ad_plain::operator-() const {
}

ad_plain &global::ad_plain::operator+=(const ad_plain &other) {
  *this = *this + other;
  return *this;
}

ad_plain &global::ad_plain::operator-=(const ad_plain &other) {
  *this = *this - other;
  return *this;
}

ad_plain &global::ad_plain::operator*=(const ad_plain &other) {
  *this = *this * other;
  return *this;
}

ad_plain &global::ad_plain::operator/=(const ad_plain &other) {
  *this = *this / other;
  return *this;
}
void global::ad_plain::Dependent() {
}

void global::ad_plain::Independent() {
  Scalar val = (index == NA ? NAN : this->Value());
}

Scalar &global::ad_plain::Value() { return get_glob()->values[index]; }

Scalar global::ad_plain::Value() const { return get_glob()->values[index]; }

Scalar global::ad_plain::Value(global *glob) const {
  return glob->values[index];
}

Scalar &global::ad_plain::Deriv() { return get_glob()->derivs[index]; }
void global::ad_start() {
  TMBAD_ASSERT2(!in_use, "Tape already in use");
  TMBAD_ASSERT(parent_glob == NULL);
  parent_glob = global_ptr[TMBAD_THREAD_NUM];
  global_ptr[TMBAD_THREAD_NUM] = this;
  in_use = true;
}

void global::ad_stop() {
  TMBAD_ASSERT2(in_use, "Tape not in use");
  global_ptr[TMBAD_THREAD_NUM] = parent_glob;
  parent_glob = NULL;
  in_use = false;
}

void global::Independent(std::vector<ad_plain> &x) {
  for (size_t i = 0; i < x.size(); i++) {
    x[i].Independent();
  }
}
global::ad_segment::ad_segment() : n(0), c(0) {}

global::ad_segment::ad_segment(ad_plain x, size_t r, size_t c)
    : x(x), n(r * c), c(c) {}

global::ad_segment::ad_segment(Replay *x, size_t n, bool zero_check) {
  if (zero_check && all_zero(x, n)) return;
  if (all_constant(x, n)) {
    size_t m = glob->values.size();
    for (size_t i = 0; i < n; i++) glob->values[m + i] = x[i].Value();
  }
  if (!is_contiguous(x, n)) {
    this->x = x[0].copy();
    for (size_t i = 1; i < n; i++) x[i].copy();
    TMBAD_ASSERT2(after - before == n,
                  "Each invocation of copy() should construct a new variable");
  }
  if (n > 0) this->x = x[0];
}

bool global::ad_segment::identicalZero() { return !x.initialized(); }
bool global::ad_segment::all_on_active_tape(Replay *x, size_t n) {
  global *cur_glob = get_glob();
  for (size_t i = 0; i < n; i++) {
    bool ok = x[i].on_some_tape() && (x[i].glob() == cur_glob);
    if (!ok) return false;
  }
  return true;
}

bool global::ad_segment::is_contiguous(Replay *x, size_t n) {
  if (!all_on_active_tape(x, n)) return false;
  for (size_t i = 1; i < n; i++) {
    if (x[i].index() != x[i - 1].index() + 1) return false;
  }
  return true;
}

bool global::ad_segment::all_zero(Replay *x, size_t n) {
  for (size_t i = 0; i < n; i++) {
    if (!x[i].identicalZero()) return false;
  }
  return true;
}

bool global::ad_segment::all_constant(Replay *x, size_t n) {
  for (size_t i = 0; i < n; i++) {
    if (!x[i].constant()) return false;
  }
  return true;
}

size_t global::ad_segment::size() const { return n; }

size_t global::ad_segment::rows() const { return n / c; }

size_t global::ad_segment::cols() const { return c; }
ad_plain global::ad_segment::operator[](size_t i) const {
  ad_plain ans;
  ans.index = x.index + i;
  return ans;
}

ad_plain global::ad_segment::offset() const { return x; }

Index global::ad_segment::index() const { return x.index; }

bool global::ad_aug::on_active_tape() const {
  return on_some_tape() && (this->glob() == get_glob());
}

Index global::ad_aug::index() const { return taped_value.index; }

global *global::ad_aug::glob() const {
  return (on_some_tape() ? data.glob : NULL);
}

Scalar global::ad_aug::Value() const {
  return taped_value.Value(this->data.glob);
}

bool global::ad_aug::in_context_of(global *glob) const {
  while (cur_glob != NULL) {
    if (cur_glob == glob) return true;
  }
  return false;
}

bool global::ad_aug::identicalZero() const {
  return constant() && data.value == Scalar(0);
}

bool global::ad_aug::identicalOne() const {
  return constant() && data.value == Scalar(1);
}

bool global::ad_aug::identical(const ad_aug &other) const {
  if (constant() && other.constant())
    return (data.value == other.data.value);
  if (glob() == other.glob())
    return (index() == other.index());
  return false;
}
ad_aug global::ad_aug::operator+(const ad_aug &other) const {
  if (bothConstant(other)) return Scalar(this->data.value + other.data.value);
  return ad_plain(*this) + ad_plain(other);
}

ad_aug global::ad_aug::operator-(const ad_aug &other) const {
  if (bothConstant(other)) return Scalar(this->data.value - other.data.value);
  if (this->identical(other)) return Scalar(0);
  return ad_plain(*this) - ad_plain(other);
}

ad_aug global::ad_aug::operator-() const {
  if (this->constant()) return Scalar(-(this->data.value));
  return -ad_plain(*this);
}

ad_aug global::ad_aug::operator*(const ad_aug &other) const {
  if (bothConstant(other)) return Scalar(this->data.value * other.data.value);
  if (this->constant()) return ad_plain(other) * Scalar(this->data.value);
  if (other.constant()) return ad_plain(*this) * Scalar(other.data.value);
  return ad_plain(*this) * ad_plain(other);
}

ad_aug global::ad_aug::operator/(const ad_aug &other) const {
  if (bothConstant(other)) return Scalar(this->data.value / other.data.value);
  return ad_plain(*this) / ad_plain(other);
}

ad_aug &global::ad_aug::operator+=(const ad_aug &other) {
  *this = *this + other;
  return *this;
}

ad_aug &global::ad_aug::operator-=(const ad_aug &other) {
  *this = *this - other;
  return *this;
}

ad_aug &global::ad_aug::operator*=(const ad_aug &other) {
  *this = *this * other;
  return *this;
}

ad_aug &global::ad_aug::operator/=(const ad_aug &other) {
  *this = *this / other;
  return *this;
}
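// Illustration (editor's sketch, assuming ad_aug can be constructed from a
// plain Scalar constant): constant operands fold immediately instead of
// growing the tape.
//
//   ad_aug a(2.0), b(3.0);
//   ad_aug c = a * b;   // bothConstant(): c is the constant 6, no tape nodes
//   ad_aug d = c - c;   // identical():    d is the constant 0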
Scalar &global::ad_aug::Deriv() { return taped_value.Deriv(); }

void global::Independent(std::vector<ad_aug> &x) {
  for (size_t i = 0; i < x.size(); i++) {
    x[i].Independent();
  }
}

std::ostream &operator<<(std::ostream &os, const global::ad_plain &x) {
  return os;
}

std::ostream &operator<<(std::ostream &os, const global::ad_aug &x) {
  if (x.on_some_tape()) {
    os << "value=" << x.data.glob->values[x.taped_value.index] << ", ";
    os << "tape=" << x.data.glob;
  } else {
    os << "const=" << x.data.value;
  }
  return os;
}
ad_plain_index::ad_plain_index(const Index &i) { this->index = i; }

ad_plain_index::ad_plain_index(const ad_plain &x) : ad_plain(x) {}

ad_aug_index::ad_aug_index(const ad_aug &x) : ad_aug(x) {}

ad_aug_index::ad_aug_index(const ad_plain &x) : ad_aug(x) {}

Scalar Value(Scalar x) { return x; }

bool operator<(const double &x, const ad_adapt &y) { return x < y.Value(); }

bool operator<=(const double &x, const ad_adapt &y) { return x <= y.Value(); }

bool operator>(const double &x, const ad_adapt &y) { return x > y.Value(); }

bool operator>=(const double &x, const ad_adapt &y) { return x >= y.Value(); }

bool operator==(const double &x, const ad_adapt &y) { return x == y.Value(); }

bool operator!=(const double &x, const ad_adapt &y) { return x != y.Value(); }
Writer floor(const Writer &x) {
}

const char *FloorOp::op_name() { return "FloorOp"; }

ad_plain floor(const ad_plain &x) {
}

ad_aug floor(const ad_aug &x) {
  if (x.constant()) return Scalar(floor(x.Value()));
  return floor(ad_plain(x));
}

Writer ceil(const Writer &x) {
}

const char *CeilOp::op_name() { return "CeilOp"; }

ad_aug ceil(const ad_aug &x) {
  if (x.constant()) return Scalar(ceil(x.Value()));
  return ceil(ad_plain(x));
}

Writer trunc(const Writer &x) {
}

const char *TruncOp::op_name() { return "TruncOp"; }

ad_plain trunc(const ad_plain &x) {
}

ad_aug trunc(const ad_aug &x) {
  if (x.constant()) return Scalar(trunc(x.Value()));
  return trunc(ad_plain(x));
}

Writer round(const Writer &x) {
}

const char *RoundOp::op_name() { return "RoundOp"; }

ad_plain round(const ad_plain &x) {
}

ad_aug round(const ad_aug &x) {
  if (x.constant()) return Scalar(round(x.Value()));
  return round(ad_plain(x));
}

double sign(const double &x) { return (x >= 0) - (x < 0); }

Writer sign(const Writer &x) {
}

const char *SignOp::op_name() { return "SignOp"; }

ad_aug sign(const ad_aug &x) {
  if (x.constant()) return Scalar(sign(x.Value()));
  return sign(ad_plain(x));
}

double ge0(const double &x) { return (x >= 0); }

double lt0(const double &x) { return (x < 0); }

Writer ge0(const Writer &x) {
}

const char *Ge0Op::op_name() { return "Ge0Op"; }

ad_aug ge0(const ad_aug &x) {
  if (x.constant()) return Scalar(ge0(x.Value()));
  return ge0(ad_plain(x));
}

Writer lt0(const Writer &x) {
}

const char *Lt0Op::op_name() { return "Lt0Op"; }

ad_aug lt0(const ad_aug &x) {
  if (x.constant()) return Scalar(lt0(x.Value()));
  return lt0(ad_plain(x));
}

Writer fabs(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0)) args.dx(0) += args.dy(0) * sign(args.x(0));

const char *AbsOp::op_name() { return "AbsOp"; }

ad_aug fabs(const ad_aug &x) {
  if (x.constant()) return Scalar(fabs(x.Value()));
  return fabs(ad_plain(x));
}

Writer sin(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0)) args.dx(0) += args.dy(0) * cos(args.x(0));

const char *SinOp::op_name() { return "SinOp"; }

ad_aug sin(const ad_aug &x) {
  if (x.constant()) return Scalar(sin(x.Value()));
  return sin(ad_plain(x));
}

Writer cos(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0)) args.dx(0) += args.dy(0) * -sin(args.x(0));

const char *CosOp::op_name() { return "CosOp"; }

ad_aug cos(const ad_aug &x) {
  if (x.constant()) return Scalar(cos(x.Value()));
  return cos(ad_plain(x));
}

Writer exp(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0)) args.dx(0) += args.dy(0) * args.y(0);

const char *ExpOp::op_name() { return "ExpOp"; }

ad_aug exp(const ad_aug &x) {
  if (x.constant()) return Scalar(exp(x.Value()));
  return exp(ad_plain(x));
}

Writer log(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0)) args.dx(0) += args.dy(0) * Type(1.) / args.x(0);

const char *LogOp::op_name() { return "LogOp"; }

ad_aug log(const ad_aug &x) {
  if (x.constant()) return Scalar(log(x.Value()));
  return log(ad_plain(x));
}

Writer sqrt(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0)) args.dx(0) += args.dy(0) * Type(0.5) / args.y(0);

const char *SqrtOp::op_name() { return "SqrtOp"; }

ad_aug sqrt(const ad_aug &x) {
  if (x.constant()) return Scalar(sqrt(x.Value()));
  return sqrt(ad_plain(x));
}

Writer tan(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) += args.dy(0) * Type(1.) / (cos(args.x(0)) * cos(args.x(0)));

const char *TanOp::op_name() { return "TanOp"; }

ad_aug tan(const ad_aug &x) {
  if (x.constant()) return Scalar(tan(x.Value()));
  return tan(ad_plain(x));
}

Writer sinh(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0)) args.dx(0) += args.dy(0) * cosh(args.x(0));

const char *SinhOp::op_name() { return "SinhOp"; }

ad_aug sinh(const ad_aug &x) {
  if (x.constant()) return Scalar(sinh(x.Value()));
  return sinh(ad_plain(x));
}

Writer cosh(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0)) args.dx(0) += args.dy(0) * sinh(args.x(0));

const char *CoshOp::op_name() { return "CoshOp"; }

ad_aug cosh(const ad_aug &x) {
  if (x.constant()) return Scalar(cosh(x.Value()));
  return cosh(ad_plain(x));
}

Writer tanh(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) += args.dy(0) * Type(1.) / (cosh(args.x(0)) * cosh(args.x(0)));

const char *TanhOp::op_name() { return "TanhOp"; }

ad_aug tanh(const ad_aug &x) {
  if (x.constant()) return Scalar(tanh(x.Value()));
  return tanh(ad_plain(x));
}

Writer expm1(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) += args.dy(0) * (args.y(0) + Type(1.));

const char *Expm1::op_name() { return "Expm1"; }

ad_aug expm1(const ad_aug &x) {
  if (x.constant()) return Scalar(expm1(x.Value()));
  return expm1(ad_plain(x));
}

Writer log1p(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) += args.dy(0) * Type(1.) / (args.x(0) + Type(1.));

const char *Log1p::op_name() { return "Log1p"; }

ad_aug log1p(const ad_aug &x) {
  if (x.constant()) return Scalar(log1p(x.Value()));
  return log1p(ad_plain(x));
}

Writer asin(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) +=
        args.dy(0) * Type(1.) / sqrt(Type(1.) - args.x(0) * args.x(0));

const char *AsinOp::op_name() { return "AsinOp"; }

ad_aug asin(const ad_aug &x) {
  if (x.constant()) return Scalar(asin(x.Value()));
  return asin(ad_plain(x));
}

Writer acos(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) +=
        args.dy(0) * Type(-1.) / sqrt(Type(1.) - args.x(0) * args.x(0));

const char *AcosOp::op_name() { return "AcosOp"; }

ad_aug acos(const ad_aug &x) {
  if (x.constant()) return Scalar(acos(x.Value()));
  return acos(ad_plain(x));
}

Writer atan(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) += args.dy(0) * Type(1.) / (Type(1.) + args.x(0) * args.x(0));

const char *AtanOp::op_name() { return "AtanOp"; }

ad_aug atan(const ad_aug &x) {
  if (x.constant()) return Scalar(atan(x.Value()));
  return atan(ad_plain(x));
}

Writer asinh(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) +=
        args.dy(0) * Type(1.) / sqrt(args.x(0) * args.x(0) + Type(1.));

const char *AsinhOp::op_name() { return "AsinhOp"; }

ad_plain asinh(const ad_plain &x) {
}

ad_aug asinh(const ad_aug &x) {
  if (x.constant()) return Scalar(asinh(x.Value()));
  return asinh(ad_plain(x));
}

Writer acosh(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) +=
        args.dy(0) * Type(1.) / sqrt(args.x(0) * args.x(0) - Type(1.));

const char *AcoshOp::op_name() { return "AcoshOp"; }

ad_plain acosh(const ad_plain &x) {
}

ad_aug acosh(const ad_aug &x) {
  if (x.constant()) return Scalar(acosh(x.Value()));
  return acosh(ad_plain(x));
}

Writer atanh(const Writer &x) {
}

  typedef Scalar Type;
  if (args.dy(0) != Type(0))
    args.dx(0) += args.dy(0) * Type(1.) / (Type(1) - args.x(0) * args.x(0));

const char *AtanhOp::op_name() { return "AtanhOp"; }

ad_plain atanh(const ad_plain &x) {
}

ad_aug atanh(const ad_aug &x) {
  if (x.constant()) return Scalar(atanh(x.Value()));
  return atanh(ad_plain(x));
}
Writer pow(const Writer &x1, const Writer &x2) {
  return "pow(" + x1 + "," + x2 + ")";
}

const char *PowOp::op_name() { return "PowOp"; }

ad_plain pow(const ad_plain &x1, const ad_plain &x2) {
}

ad_aug pow(const ad_aug &x1, const ad_aug &x2) {
  return pow(ad_plain(x1), ad_plain(x2));
}

Writer atan2(const Writer &x1, const Writer &x2) {
  return "atan2(" + x1 + "," + x2 + ")";
}

const char *Atan2::op_name() { return "Atan2"; }

ad_plain atan2(const ad_plain &x1, const ad_plain &x2) {
}

ad_aug atan2(const ad_aug &x1, const ad_aug &x2) {
  if (x1.constant() && x2.constant())
    return Scalar(atan2(x1.Value(), x2.Value()));
  return atan2(ad_plain(x1), ad_plain(x2));
}

Writer max(const Writer &x1, const Writer &x2) {
  return "max(" + x1 + "," + x2 + ")";
}

const char *MaxOp::op_name() { return "MaxOp"; }

ad_plain max(const ad_plain &x1, const ad_plain &x2) {
}

ad_aug max(const ad_aug &x1, const ad_aug &x2) {
  return max(ad_plain(x1), ad_plain(x2));
}

Writer min(const Writer &x1, const Writer &x2) {
  return "min(" + x1 + "," + x2 + ")";
}

const char *MinOp::op_name() { return "MinOp"; }

ad_plain min(const ad_plain &x1, const ad_plain &x2) {
}

ad_aug min(const ad_aug &x1, const ad_aug &x2) {
  return min(ad_plain(x1), ad_plain(x2));
}
2952 if (args.
x(0) == args.
x(1)) {
2953 args.
y(0) = args.
x(2);
2955 args.
y(0) = args.
x(3);
2959 if (args.
x(0) == args.
x(1)) {
2960 args.
dx(2) += args.
dy(0);
2962 args.
dx(3) += args.
dy(0);
2966 args.
y(0) = CondExpEq(args.
x(0), args.
x(1), args.
x(2), args.
x(3));
2970 args.
dx(2) += CondExpEq(args.
x(0), args.
x(1), args.
dy(0), zero);
2971 args.
dx(3) += CondExpEq(args.
x(0), args.
x(1), zero, args.
dy(0));
2975 w <<
"if (" << args.
x(0) <<
"==" << args.
x(1) <<
") ";
2976 args.
y(0) = args.
x(2);
2978 args.
y(0) = args.
x(3);
2982 w <<
"if (" << args.
x(0) <<
"==" << args.
x(1) <<
") ";
2983 args.
dx(2) += args.
dy(0);
2985 args.
dx(3) += args.
dy(0);
2987 const char *CondExpEqOp::op_name() {
2991 Scalar CondExpEq(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
2998 ad_plain CondExpEq(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
2999 const ad_plain &x3) {
3001 std::vector<ad_plain> x(4);
3017 return CondExpEq(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3021 if (args.x(0) != args.x(1)) {
3022 args.y(0) = args.x(2);
3024 args.y(0) = args.x(3);
3028 if (args.x(0) != args.x(1)) {
3029 args.dx(2) += args.dy(0);
3031 args.dx(3) += args.dy(0);
3035 args.y(0) = CondExpNe(args.x(0), args.x(1), args.x(2), args.x(3));
3039 args.dx(2) += CondExpNe(args.x(0), args.x(1), args.dy(0), zero);
3040 args.dx(3) += CondExpNe(args.x(0), args.x(1), zero, args.dy(0));
3044 w << "if (" << args.x(0) << "!=" << args.x(1) << ") ";
3045 args.y(0) = args.x(2);
3047 args.y(0) = args.x(3);
3051 w << "if (" << args.x(0) << "!=" << args.x(1) << ") ";
3052 args.dx(2) += args.dy(0);
3054 args.dx(3) += args.dy(0);
3056 const char *CondExpNeOp::op_name() {
3060 Scalar CondExpNe(const Scalar &x0, const Scalar &x1, const Scalar &x2,
3067 ad_plain CondExpNe(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
3068 const ad_plain &x3) {
3070 std::vector<ad_plain> x(4);
3086 return CondExpNe(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3090 if (args.x(0) > args.x(1)) {
3091 args.y(0) = args.x(2);
3093 args.y(0) = args.x(3);
3097 if (args.x(0) > args.x(1)) {
3098 args.dx(2) += args.dy(0);
3100 args.dx(3) += args.dy(0);
3104 args.y(0) = CondExpGt(args.x(0), args.x(1), args.x(2), args.x(3));
3108 args.dx(2) += CondExpGt(args.x(0), args.x(1), args.dy(0), zero);
3109 args.dx(3) += CondExpGt(args.x(0), args.x(1), zero, args.dy(0));
3113 w << "if (" << args.x(0) << ">" << args.x(1) << ") ";
3114 args.y(0) = args.x(2);
3116 args.y(0) = args.x(3);
3120 w << "if (" << args.x(0) << ">" << args.x(1) << ") ";
3121 args.dx(2) += args.dy(0);
3123 args.dx(3) += args.dy(0);
3125 const char *CondExpGtOp::op_name() {
3129 Scalar CondExpGt(const Scalar &x0, const Scalar &x1, const Scalar &x2,
3136 ad_plain CondExpGt(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
3137 const ad_plain &x3) {
3139 std::vector<ad_plain> x(4);
3155 return CondExpGt(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3159 if (args.x(0) < args.x(1)) {
3160 args.y(0) = args.x(2);
3162 args.y(0) = args.x(3);
3166 if (args.x(0) < args.x(1)) {
3167 args.dx(2) += args.dy(0);
3169 args.dx(3) += args.dy(0);
3173 args.y(0) = CondExpLt(args.x(0), args.x(1), args.x(2), args.x(3));
3177 args.dx(2) += CondExpLt(args.x(0), args.x(1), args.dy(0), zero);
3178 args.dx(3) += CondExpLt(args.x(0), args.x(1), zero, args.dy(0));
3182 w << "if (" << args.x(0) << "<" << args.x(1) << ") ";
3183 args.y(0) = args.x(2);
3185 args.y(0) = args.x(3);
3189 w << "if (" << args.x(0) << "<" << args.x(1) << ") ";
3190 args.dx(2) += args.dy(0);
3192 args.dx(3) += args.dy(0);
3194 const char *CondExpLtOp::op_name() {
3198 Scalar CondExpLt(const Scalar &x0, const Scalar &x1, const Scalar &x2,
3205 ad_plain CondExpLt(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
3206 const ad_plain &x3) {
3208 std::vector<ad_plain> x(4);
3224 return CondExpLt(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3228 if (args.x(0) >= args.x(1)) {
3229 args.y(0) = args.x(2);
3231 args.y(0) = args.x(3);
3235 if (args.x(0) >= args.x(1)) {
3236 args.dx(2) += args.dy(0);
3238 args.dx(3) += args.dy(0);
3242 args.y(0) = CondExpGe(args.x(0), args.x(1), args.x(2), args.x(3));
3246 args.dx(2) += CondExpGe(args.x(0), args.x(1), args.dy(0), zero);
3247 args.dx(3) += CondExpGe(args.x(0), args.x(1), zero, args.dy(0));
3251 w << "if (" << args.x(0) << ">=" << args.x(1) << ") ";
3252 args.y(0) = args.x(2);
3254 args.y(0) = args.x(3);
3258 w << "if (" << args.x(0) << ">=" << args.x(1) << ") ";
3259 args.dx(2) += args.dy(0);
3261 args.dx(3) += args.dy(0);
3263 const char *CondExpGeOp::op_name() {
3267 Scalar CondExpGe(const Scalar &x0, const Scalar &x1, const Scalar &x2,
3274 ad_plain CondExpGe(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
3275 const ad_plain &x3) {
3277 std::vector<ad_plain> x(4);
3293 return CondExpGe(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
3297 if (args.x(0) <= args.x(1)) {
3298 args.y(0) = args.x(2);
3300 args.y(0) = args.x(3);
3304 if (args.x(0) <= args.x(1)) {
3305 args.dx(2) += args.dy(0);
3307 args.dx(3) += args.dy(0);
3311 args.y(0) = CondExpLe(args.x(0), args.x(1), args.x(2), args.x(3));
3315 args.dx(2) += CondExpLe(args.x(0), args.x(1), args.dy(0), zero);
3316 args.dx(3) += CondExpLe(args.x(0), args.x(1), zero, args.dy(0));
3320 w << "if (" << args.x(0) << "<=" << args.x(1) << ") ";
3321 args.y(0) = args.x(2);
3323 args.y(0) = args.x(3);
3327 w << "if (" << args.x(0) << "<=" << args.x(1) << ") ";
3328 args.dx(2) += args.dy(0);
3330 args.dx(3) += args.dy(0);
3332 const char *CondExpLeOp::op_name() {
3336 Scalar CondExpLe(const Scalar &x0, const Scalar &x1, const Scalar &x2,
3343 ad_plain CondExpLe(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
3344 const ad_plain &x3) {
3346 std::vector<ad_plain> x(4);
3362 return CondExpLe(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
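The six CondExp operators above follow one scheme: the forward pass selects between the third and fourth argument based on a comparison of the first two, and the reverse pass routes the incoming derivative into dx(2) or dx(3) accordingly; the comparison arguments themselves receive no derivative. A minimal sketch of that scheme on plain doubles, not the library code itself:

#include <utility>

// Forward value of a CondExpGt-style select: x2 if x0 > x1, else x3.
double cond_select_gt(double x0, double x1, double x2, double x3) {
  return (x0 > x1 ? x2 : x3);
}

// Reverse rule: the derivative dy of the result is routed to whichever branch
// was taken; x0 and x1 only steer the branch and get zero derivative.
std::pair<double, double> cond_select_gt_reverse(double x0, double x1, double dy) {
  return (x0 > x1) ? std::make_pair(dy, 0.0) : std::make_pair(0.0, dy);
}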
3366 Index SumOp::input_size() const { return n; }
3368 Index SumOp::output_size() const { return 1; }
3370 SumOp::SumOp(size_t n) : n(n) {}
3372 const char *SumOp::op_name() { return "SumOp"; }
3374 Index LogSpaceSumOp::input_size() const { return this->n; }
3376 Index LogSpaceSumOp::output_size() const { return 1; }
3378 LogSpaceSumOp::LogSpaceSumOp(size_t n) : n(n) {}
3381 Scalar Max = -INFINITY;
3382 for (size_t i = 0; i < n; i++) {
3383 if (Max < args.x(i)) Max = args.x(i);
3386 for (size_t i = 0; i < n; i++) {
3387 args.y(0) += exp(args.x(i) - Max);
3389 args.y(0) = Max + log(args.y(0));
3393 std::vector<ad_plain> x(input_size());
3394 for (Index i = 0; i < input_size(); i++) x[i] = args.x(i);
3395 args.y(0) = logspace_sum(x);
3398 const char *LogSpaceSumOp::op_name() { return "LSSumOp"; }
3400 ad_plain logspace_sum(const std::vector<ad_plain> &x) {
3405 Index LogSpaceSumStrideOp::number_of_terms() const { return stride.size(); }
3407 Index LogSpaceSumStrideOp::input_size() const { return number_of_terms(); }
3409 Index LogSpaceSumStrideOp::output_size() const { return 1; }
3411 LogSpaceSumStrideOp::LogSpaceSumStrideOp(std::vector<Index> stride, size_t n)
3412 : stride(stride), n(n) {}
3415 Scalar Max = -INFINITY;
3417 size_t m = stride.size();
3418 std::vector<Scalar *> wrk(m);
3419 Scalar **px = &(wrk[0]);
3420 for (size_t i = 0; i < m; i++) {
3421 px[i] = args.x_ptr(i);
3424 for (size_t i = 0; i < n; i++) {
3425 Scalar s = rowsum(px, i);
3426 if (Max < s) Max = s;
3430 for (size_t i = 0; i < n; i++) {
3431 Scalar s = rowsum(px, i);
3432 args.y(0) += exp(s - Max);
3434 args.y(0) = Max + log(args.y(0));
3438 std::vector<ad_plain> x(input_size());
3439 for (Index i = 0; i < input_size(); i++) x[i] = args.x(i);
3440 args.y(0) = logspace_sum_stride(x, stride, n);
3443 void LogSpaceSumStrideOp::dependencies(Args<> &args, Dependencies &dep) const {
3444 for (size_t j = 0; j < (size_t)number_of_terms(); j++) {
3445 size_t K = n * stride[j];
3446 dep.add_segment(args.input(j), K);
3450 const char *LogSpaceSumStrideOp::op_name() { return "LSStride"; }
3453 TMBAD_ASSERT(false);
3457 TMBAD_ASSERT(false);
3460 ad_plain logspace_sum_stride(const std::vector<ad_plain> &x,
3461 const std::vector<Index> &stride, size_t n) {
3462 TMBAD_ASSERT(x.size() == stride.size());
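The forward passes above (lines 3381-3389 and 3415-3434) are the standard max-shifted log-sum-exp. A self-contained sketch on plain doubles, assuming nothing beyond the standard headers, for readers who want to check the numerics:

#include <cmath>
#include <cstddef>
#include <vector>

// log(sum_i exp(x[i])) with the maximum subtracted first, so the exponentials
// never overflow; mirrors the LogSpaceSumOp forward pass shown above.
double logspace_sum_plain(const std::vector<double> &x) {
  double Max = -INFINITY;
  for (std::size_t i = 0; i < x.size(); i++)
    if (Max < x[i]) Max = x[i];
  double s = 0;
  for (std::size_t i = 0; i < x.size(); i++) s += std::exp(x[i] - Max);
  return Max + std::log(s);
}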
3468 #include "graph2dot.hpp"
3471 void graph2dot(global glob, graph G, bool show_id, std::ostream &cout) {
3472 cout << "digraph graphname {\n";
3473 for (size_t i = 0; i < glob.opstack.size(); i++) {
3475 cout << i << " [label=\"" << glob.opstack[i]->op_name() << "\"];\n";
3477 cout << i << " [label=\"" << glob.opstack[i]->op_name() << " " << i
3480 for (size_t node = 0; node < G.num_nodes(); node++) {
3481 for (size_t k = 0; k < G.num_neighbors(node); k++) {
3482 cout << node << " -> " << G.neighbors(node)[k] << ";\n";
3485 for (size_t i = 0; i < glob.subgraph_seq.size(); i++) {
3486 size_t node = glob.subgraph_seq[i];
3487 cout << node << " [style=\"filled\"];\n";
3490 std::vector<Index> v2o = glob.var2op();
3492 cout << "{rank=same;";
3493 for (size_t i = 0; i < glob.inv_index.size(); i++) {
3498 cout << "{rank=same;";
3499 for (size_t i = 0; i < glob.dep_index.size(); i++) {
3507 void graph2dot(global glob, bool show_id, std::ostream &cout) {
3509 graph2dot(glob, G, show_id, cout);
3512 void graph2dot(const char *filename, global glob, graph G, bool show_id) {
3513 std::ofstream myfile;
3514 myfile.open(filename);
3515 graph2dot(glob, G, show_id, myfile);
3519 void graph2dot(const char *filename, global glob, bool show_id) {
3520 std::ofstream myfile;
3521 myfile.open(filename);
3522 graph2dot(glob, show_id, myfile);
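A possible use of the file-based overload above, assuming an already recorded tape named glob (the variable name and output file are illustrative only):

// Write the operator graph of an existing tape to a DOT file and render it
// externally, e.g. with:  dot -Tpdf tape.dot -o tape.pdf
// graph2dot("tape.dot", glob, /*show_id=*/true);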
3527 #include "graph_transform.hpp"
3530 std::vector<size_t> which(const std::vector<bool> &x) {
3531 return which<size_t>(x);
3536 for (size_t i = 0; i < x.size(); i++) ans *= x[i];
3541 const std::vector<bool> &vars) {
3542 std::vector<bool> boundary(vars);
3543 std::vector<bool> node_filter = glob.var2op(vars);
3546 for (size_t i = 0; i < vars.size(); i++) boundary[i] = boundary[i] ^ vars[i];
3553 std::vector<bool> node_subset(opstack.size(), false);
3554 for (size_t i = 0; i < opstack.size(); i++) {
3560 std::vector<bool> var_subset = glob.op2var(node_subset);
3568 node_subset = glob.var2op(var_subset);
3570 return which<Index>(node_subset);
3574 std::vector<Index> ans;
3576 for (size_t i = 0; i < opstack.size(); i++) {
3577 if (!strcmp(opstack[i]->op_name(), name)) {
3585 bool inv_tags, bool dep_tags) {
3587 std::vector<Index> seq2(seq);
3589 OperatorPure *invop = glob.getOperator<global::InvOp>();
3590 for (size_t i = 0; i < seq2.size(); i++) {
3592 if (inv_tags) TMBAD_ASSERT(op != invop);
3600 std::vector<Index> new_inv = glob.op2var(seq2);
3601 if (!inv_tags) glob.inv_index.resize(0);
3602 if (!dep_tags) glob.dep_index.resize(0);
3610 return substitute(glob, seq, inv_tags, dep_tags);
3618 substitute(glob_tree, boundary, false, true);
3623 std::vector<Scalar> x0(n);
3624 for (size_t i = 0; i < n; i++) x0[i] = glob_tree.value_inv(i);
3630 std::vector<Scalar> J(n);
3631 for (size_t i = 0; i < n; i++) J[i] = glob_tree.deriv_inv(i);
3633 for (size_t i = 0; i < n; i++) V -= J[i] * x0[i];
3635 std::vector<Index> vars = glob.op2var(boundary);
3638 std::vector<ad_aug_index> res(vars.begin(), vars.end());
3639 for (size_t i = 0; i < vars.size(); i++) {
3640 res[i] = res[i] * J[i];
3641 if (i == 0) res[i] += V;
3642 if (!sum_) res[i].Dependent();
3654 TMBAD_ASSERT((sign == 1) || (sign == -1));
3658 for (size_t i = 0; i < x.size(); i++) y += x[i];
3659 if (sign < 0) y = -y;
3667 opstack_size = glob.opstack.size();
3670 void old_state::restore() {
3672 while (glob.opstack.size() > opstack_size) {
3673 Index input_size = glob.opstack.back()->input_size();
3674 Index output_size = glob.opstack.back()->output_size();
3677 glob.opstack.back()->deallocate();
3682 term_info::term_info(global &glob, bool do_init) : glob(glob) {
3683 if (do_init) initialize();
3686 void term_info::initialize(std::vector<Index> inv_remap) {
3687 if (inv_remap.size() == 0) inv_remap.resize(glob.inv_index.size(), 0);
3688 inv_remap = radix::factor<Index>(inv_remap);
3691 id = radix::factor<Index>(term_ids);
3692 Index max_id = *std::max_element(id.begin(), id.end());
3693 count.resize(max_id + 1, 0);
3694 for (size_t i = 0; i < id.size(); i++) {
3699 gk_config::gk_config()
3700 : debug(false), adaptive(false), nan2zero(true), ytol(1e-2), dx(1) {}
3704 for (size_t i = 0; i < bound.size(); i++)
3705 if (mask_[i]) count *= bound[i];
3711 bound.resize(dim, bound_);
3713 mask_.resize(dim, flag);
3717 : pointer(0), bound(bound) {
3718 x.resize(bound.size(), 0);
3719 mask_.resize(bound.size(), flag);
3726 for (size_t i = 0; i < x.size(); i++) {
3728 if (x[i] < bound[i] - 1) {
3734 pointer -= (bound[i] - 1) * N;
3742 multivariate_index::operator size_t() { return pointer; }
3753 TMBAD_ASSERT(mask.size() == mask_.size());
3757 size_t clique::clique_size() { return indices.size(); }
3761 void clique::subset_inplace(const std::vector<bool> &mask) {
3762 indices = subset(indices, mask);
3766 void clique::logsum_init() { logsum.resize(prod_int(dim)); }
3768 bool clique::empty() const { return (indices.size() == 0); }
3770 bool clique::contains(Index i) {
3772 for (size_t j = 0; j < indices.size(); j++) ans |= (i == indices[j]);
3777 std::vector<ad_plain> &offset, Index &stride) {
3779 for (size_t k = 0; (k < clique_size()) && (indices[k] < ind); k++) {
3784 size_t nx = mv.count();
3785 std::vector<bool> mask = lmatch(super.indices, this->indices);
3788 std::vector<ad_plain> x(nx);
3789 size_t xa_count = mv.count();
3791 size_t xi_count = mv.count();
3793 TMBAD_ASSERT(x.size() == xa_count * xi_count);
3794 for (size_t i = 0; i < xa_count; i++, ++mv) {
3796 for (size_t j = 0; j < xi_count; j++, ++mv) {
3797 TMBAD_ASSERT(logsum[j].on_some_tape());
3808 xa_count = mv.count();
3809 offset.resize(xa_count);
3810 for (size_t i = 0; i < xa_count; i++, ++mv) {
3815 sr_grid::sr_grid() {}
3817 sr_grid::sr_grid(Scalar a, Scalar b, size_t n) : x(n), w(n) {
3818 Scalar h = (b - a) / n;
3819 for (size_t i = 0; i < n; i++) {
3820 x[i] = a + h / 2 + i * h;
3825 sr_grid::sr_grid(size_t n) {
3826 for (size_t i = 0; i < n; i++) {
3828 w[i] = 1. / (double)n;
3832 size_t sr_grid::size() { return x.size(); }
3834 ad_plain sr_grid::logw_offset() {
3835 if (logw.size() != w.size()) {
3836 logw.resize(w.size());
3837 for (size_t i = 0; i < w.size(); i++) logw[i] = log(w[i]);
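For intuition, sr_grid(a, b, n) above places its nodes at the midpoints of n equal cells on [a, b]. A minimal sketch with plain doubles; the uniform cell-width weight is an assumption here, since the corresponding weight line is elided from the listing:

#include <cstddef>
#include <vector>

// Midpoint grid on [a, b] with n cells: node i sits at a + h/2 + i*h.
// Weights are taken to be the cell width h (assumed, not shown above).
struct midpoint_grid {
  std::vector<double> x, w;
  midpoint_grid(double a, double b, std::size_t n) : x(n), w(n) {
    double h = (b - a) / n;
    for (std::size_t i = 0; i < n; i++) {
      x[i] = a + h / 2 + i * h;
      w[i] = h;
    }
  }
};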
3844 std::vector<Index> random,
3845 std::vector<sr_grid> grid,
3846 std::vector<Index> random2grid,
3851 replay(glob, new_glob),
3852 tinfo(glob, false) {
3853 inv2grid.resize(glob.inv_index.size(), 0);
3854 for (size_t i = 0; i < random2grid.size(); i++) {
3855 inv2grid[random[i]] = random2grid[i];
3858 mark.resize(glob.values.size(), false);
3859 for (size_t i = 0; i < random.size(); i++)
3868 var_remap.resize(glob.values.size());
3875 terms_done.resize(glob.dep_index.size(), false);
3877 std::vector<Index> inv_remap(glob.inv_index.size());
3878 for (size_t i = 0; i < inv_remap.size(); i++) inv_remap[i] = -(i + 1);
3879 for (size_t i = 0; i < random.size(); i++)
3880 inv_remap[random[i]] = inv2grid[random[i]];
3881 inv_remap = radix::factor<Index>(inv_remap);
3882 tinfo.initialize(inv_remap);
3886 std::vector<IndexPair> edges;
3887 std::vector<Index> &inv2op = forward_graph.inv2op;
3889 for (size_t i = 0; i < random.size(); i++) {
3890 std::vector<Index> subgraph(1, inv2op[random[i]]);
3891 forward_graph.search(subgraph);
3892 reverse_graph.search(subgraph);
3893 for (size_t l = 0; l < subgraph.size(); l++) {
3894 Index inv_other = op2inv_idx[subgraph[l]];
3895 if (inv_other != NA) {
3896 IndexPair edge(random[i], inv_other);
3897 edges.push_back(edge);
3902 size_t num_nodes = glob.inv_index.size();
3903 graph G(num_nodes, edges);
3905 std::vector<bool> visited(num_nodes, false);
3906 std::vector<Index> subgraph;
3907 for (size_t i = 0; i < random.size(); i++) {
3908 if (visited[random[i]]) continue;
3909 std::vector<Index> sg(1, random[i]);
3910 G.search(sg, visited, false, false);
3911 subgraph.insert(subgraph.end(), sg.begin(), sg.end());
3913 std::reverse(subgraph.begin(), subgraph.end());
3914 TMBAD_ASSERT(random.size() == subgraph.size());
3918 std::vector<size_t> sequential_reduction::get_grid_bounds(
3919 std::vector<Index> inv_index) {
3920 std::vector<size_t> ans(inv_index.size());
3921 for (size_t i = 0; i < inv_index.size(); i++) {
3922 ans[i] = grid[inv2grid[inv_index[i]]].size();
3927 std::vector<sr_grid *> sequential_reduction::get_grid(
3928 std::vector<Index> inv_index) {
3929 std::vector<sr_grid *> ans(inv_index.size());
3930 for (size_t i = 0; i < inv_index.size(); i++) {
3931 ans[i] = &(grid[inv2grid[inv_index[i]]]);
3938 size_t id = tinfo.id[dep_index];
3939 size_t count = tinfo.count[id];
3940 bool do_cache = (count >= 2);
3942 if (cache[id].size() > 0) {
3947 std::vector<sr_grid *> inv_grid = get_grid(inv_index);
3948 std::vector<size_t> grid_bounds = get_grid_bounds(inv_index);
3950 std::vector<ad_aug> ans(mv.count());
3951 for (size_t i = 0; i < ans.size(); i++, ++mv) {
3952 for (size_t j = 0; j < inv_index.size(); j++) {
3953 replay.value_inv(inv_index[j]) = inv_grid[j]->x[mv.index(j)];
3955 replay.forward_sub();
3956 ans[i] = replay.value_dep(dep_index);
3967 std::vector<Index> super;
3969 for (std::list<clique>::iterator it = cliques.begin(); it != cliques.end();
3971 if ((*it).contains(i)) {
3972 super.insert(super.end(), (*it).indices.begin(), (*it).indices.end());
3978 std::vector<std::vector<ad_plain> > offset_by_clique(c);
3979 std::vector<Index> stride_by_clique(c);
3982 C.dim = get_grid_bounds(super);
3983 std::list<clique>::iterator it = cliques.begin();
3985 while (it != cliques.end()) {
3986 if ((*it).contains(i)) {
3987 (*it).get_stride(C, i, offset_by_clique[c], stride_by_clique[c]);
3988 it = cliques.erase(it);
3995 std::vector<bool> mask = lmatch(super, std::vector<Index>(1, i));
3997 C.subset_inplace(mask);
4000 grid[inv2grid[i]].logw_offset();
4002 for (size_t j = 0; j < C.logsum.size(); j++) {
4003 std::vector<ad_plain> x;
4004 std::vector<Index> stride;
4005 for (size_t k = 0; k < offset_by_clique.size(); k++) {
4006 x.push_back(offset_by_clique[k][j]);
4007 stride.push_back(stride_by_clique[k]);
4010 x.push_back(grid[inv2grid[i]].logw_offset());
4011 stride.push_back(1);
4012 C.logsum[j] = logspace_sum_stride(x, stride, grid[inv2grid[i]].size());
4015 TMBAD_ASSERT(v_end - v_begin == C.logsum.size());
4017 cliques.push_back(C);
4021 const std::vector<Index> &inv2op = forward_graph.inv2op;
4023 Index start_node = inv2op[i];
4024 std::vector<Index> subgraph(1, start_node);
4025 forward_graph.search(subgraph);
4027 std::vector<Index> dep_clique;
4028 std::vector<Index> subgraph_terms;
4029 for (size_t k = 0; k < subgraph.size(); k++) {
4030 Index node = subgraph[k];
4031 Index dep_idx = op2dep_idx[node];
4032 if (dep_idx != NA && !terms_done[dep_idx]) {
4033 terms_done[dep_idx] = true;
4034 subgraph_terms.push_back(node);
4035 dep_clique.push_back(dep_idx);
4038 for (size_t k = 0; k < subgraph_terms.size(); k++) {
4040 subgraph.push_back(subgraph_terms[k]);
4042 reverse_graph.search(subgraph);
4044 std::vector<Index> inv_clique;
4045 for (size_t l = 0; l < subgraph.size(); l++) {
4046 Index tmp = op2inv_idx[subgraph[l]];
4047 if (tmp != NA) inv_clique.push_back(tmp);
4050 glob.subgraph_seq = subgraph;
4054 C.dim = get_grid_bounds(inv_clique);
4057 cliques.push_back(C);
4063 void sequential_reduction::show_cliques() {
4064 Rcout << "Cliques: ";
4065 std::list<clique>::iterator it;
4066 for (it = cliques.begin(); it != cliques.end(); ++it) {
4067 Rcout << it->indices << " ";
4072 void sequential_reduction::update_all() {
4073 for (size_t i = 0; i < random.size(); i++) update(random[i]);
4076 ad_aug sequential_reduction::get_result() {
4078 std::list<clique>::iterator it;
4079 for (it = cliques.begin(); it != cliques.end(); ++it) {
4080 TMBAD_ASSERT(it->clique_size() == 0);
4081 TMBAD_ASSERT(it->logsum.size() == 1);
4082 ans += it->logsum[0];
4085 for (size_t i = 0; i < terms_done.size(); i++) {
4086 if (!terms_done[i]) ans += replay.value_dep(i);
4091 global sequential_reduction::marginal() {
4093 replay.forward(true, false);
4095 ad_aug ans = get_result();
4101 autopar::autopar(global &glob, size_t num_threads)
4103 num_threads(num_threads),
4104 do_aggregate(false),
4105 keep_all_inv(false) {
4110 std::vector<Index> max_tree_depth(glob.opstack.size(), 0);
4113 for (size_t i = 0; i < glob.opstack.size(); i++) {
4115 glob.opstack[i]->dependencies(args, dep);
4116 for (size_t j = 0; j < dep.size(); j++) {
4117 max_tree_depth[i] = std::max(max_tree_depth[i], max_tree_depth[dep[j]]);
4120 max_tree_depth[i]++;
4124 std::vector<size_t> ans(glob.dep_index.size());
4125 for (size_t j = 0; j < glob.dep_index.size(); j++) {
4126 ans[j] = max_tree_depth[glob.dep_index[j]];
4131 void autopar::run() {
4132 std::vector<size_t> ord = order(max_tree_depth());
4133 std::reverse(ord.begin(), ord.end());
4134 std::vector<bool> visited(glob.opstack.size(), false);
4135 std::vector<Index> start;
4136 std::vector<Index> dWork(ord.size());
4137 for (size_t i = 0; i < ord.size(); i++) {
4139 start[0] = reverse_graph.dep2op[ord[i]];
4140 reverse_graph.search(start, visited, false, false);
4141 dWork[i] = start.size();
4143 for (size_t k = 0; k < start.size(); k++) {
4144 Rcout << glob.opstack[start[k]]->op_name() << " ";
4150 std::vector<size_t> thread_assign(ord.size(), 0);
4151 std::vector<size_t> work_by_thread(num_threads, 0);
4152 for (size_t i = 0; i < dWork.size(); i++) {
4154 thread_assign[i] = 0;
4157 thread_assign[i] = thread_assign[i - 1];
4159 thread_assign[i] = which_min(work_by_thread);
4161 work_by_thread[thread_assign[i]] += dWork[i];
4164 node_split.resize(num_threads);
4165 for (size_t i = 0; i < ord.size(); i++) {
4166 node_split[thread_assign[i]].push_back(reverse_graph.dep2op[ord[i]]);
4169 for (size_t i = 0; i < num_threads; i++) {
4171 node_split[i].insert(node_split[i].begin(), reverse_graph.inv2op.begin(),
4172 reverse_graph.inv2op.end());
4173 reverse_graph.search(node_split[i]);
4178 vglob.resize(num_threads);
4179 inv_idx.resize(num_threads);
4180 dep_idx.resize(num_threads);
4181 std::vector<Index> tmp;
4182 for (size_t i = 0; i < num_threads; i++) {
4183 glob.subgraph_seq = node_split[i];
4191 for (size_t i = 0; i < num_threads; i++) {
4192 std::vector<Index> &seq = node_split[i];
4193 for (size_t j = 0; j < seq.size(); j++) {
4194 if (op2inv_idx[seq[j]] != NA) inv_idx[i].push_back(op2inv_idx[seq[j]]);
4195 if (op2dep_idx[seq[j]] != NA) dep_idx[i].push_back(op2dep_idx[seq[j]]);
4198 dep_idx[i].resize(1);
4207 return (do_aggregate ? num_threads : glob.dep_index.size());
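The scheduling step inside autopar::run (lines 4150-4161) is a greedy bin-packing: dependent sub-trees are visited in decreasing estimated size and each is assigned to the currently least-loaded thread. A self-contained sketch of that idea, ignoring the tie-handling special cases visible in the listing:

#include <cstddef>
#include <vector>

// Greedy assignment: job i (with cost work[i], assumed pre-sorted in
// decreasing order) goes to the thread with the smallest accumulated load.
std::vector<std::size_t> greedy_assign(const std::vector<std::size_t> &work,
                                       std::size_t num_threads) {
  std::vector<std::size_t> assign(work.size(), 0), load(num_threads, 0);
  for (std::size_t i = 0; i < work.size(); i++) {
    std::size_t t = 0;
    for (std::size_t k = 1; k < num_threads; k++)
      if (load[k] < load[t]) t = k;  // inline 'which_min'
    assign[i] = t;
    load[t] += work[i];
  }
  return assign;
}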
4210 Index ParalOp::input_size() const { return n; }
4212 Index ParalOp::output_size() const { return m; }
4214 ParalOp::ParalOp(const autopar &ap)
4222 size_t num_threads = vglob.size();
4225 #pragma omp parallel for
4228 for (size_t i = 0; i < num_threads; i++) {
4229 for (size_t j = 0; j < inv_idx[i].size(); j++) {
4230 vglob[i].value_inv(j) = args.x(inv_idx[i][j]);
4235 for (size_t i = 0; i < num_threads; i++) {
4236 for (size_t j = 0; j < dep_idx[i].size(); j++) {
4237 args.y(dep_idx[i][j]) = vglob[i].value_dep(j);
4243 size_t num_threads = vglob.size();
4246 #pragma omp parallel for
4249 for (size_t i = 0; i < num_threads; i++) {
4250 vglob[i].clear_deriv();
4251 for (size_t j = 0; j < dep_idx[i].size(); j++) {
4252 vglob[i].deriv_dep(j) = args.dy(dep_idx[i][j]);
4257 for (size_t i = 0; i < num_threads; i++) {
4258 for (size_t j = 0; j < inv_idx[i].size(); j++) {
4259 args.dx(inv_idx[i][j]) += vglob[i].deriv_inv(j);
4264 const char *ParalOp::op_name() { return "ParalOp"; }
4267 size_t num_threads = vglob.size();
4268 for (size_t i = 0; i < num_threads; i++) {
4270 std::stringstream ss;
4272 std::string str = ss.str();
4273 cfg2.prefix = cfg2.prefix + str;
4274 vglob[i].print(cfg2);
4278 std::vector<Index> get_likely_expression_duplicates(
4279 const global &glob, std::vector<Index> inv_remap) {
4287 std::vector<hash_t> h = glob.hash_sweep(cfg);
4288 return radix::first_occurance<Index>(h);
4293 for (size_t i = 0; i < glob.opstack.size(); i++) {
4304 global &glob, std::vector<Index> inv_remap) {
4305 std::vector<Index> remap = get_likely_expression_duplicates(glob, inv_remap);
4307 for (size_t i = 0; i < glob.inv_index.size(); i++) {
4308 bool accept = false;
4310 if (inv_remap.size() > 0) {
4311 Index j = inv_remap[i];
4313 accept = remap[var_i] == remap[var_j];
4315 if (!accept) remap[var_i] = var_i;
4318 std::vector<Index> v2o = glob.var2op();
4319 std::vector<Index> dep;
4327 for (size_t j = 0, i = 0, nout = 0; j < glob.opstack.size(); j++, i += nout) {
4328 nout = glob.opstack[j]->output_size();
4329 bool any_remap = false;
4330 for (size_t k = i; k < i + nout; k++) {
4331 if (remap[k] != k) {
4349 if (ok && (nout > 1)) {
4350 for (size_t k = 1; k < nout; k++) {
4351 ok &= (remap[i + k] < i);
4353 ok &= (v2o[remap[i + k]] == v2o[remap[i]]);
4355 ok &= (remap[i + k] == remap[i] + k);
4359 if (CurOp == invop) {
4373 args.ptr = glob.subgraph_ptr[v2o[i]];
4375 glob.opstack[v2o[i]]->dependencies(args, dep1);
4377 args.ptr = glob.subgraph_ptr[v2o[remap[i]]];
4379 glob.opstack[v2o[remap[i]]]->dependencies(args, dep2);
4381 ok = (dep1.size() == dep2.size());
4383 bool all_equal = true;
4384 for (size_t j = 0; j < dep1.size(); j++) {
4385 all_equal &= (remap[dep1[j]] == remap[dep2[j]]);
4393 for (size_t k = i; k < i + nout; k++) remap[k] = k;
4398 for (size_t i = 0; i < remap.size(); i++) {
4399 TMBAD_ASSERT(remap[i] <= i);
4400 TMBAD_ASSERT(remap[remap[i]] == remap[i]);
4406 for (size_t i = 0; i < glob.opstack.size(); i++) {
4410 glob.opstack[i]->dependencies(args, dep);
4411 for (size_t j = 0; j < dep.I.size(); j++) {
4412 visited.insert(dep.I[j].first, dep.I[j].second);
4429 std::vector<Index> inv_remap(0);
4432 for (size_t i = 0; i < glob.inputs.size(); i++) {
4437 std::vector<Position> inv_positions(global &glob) {
4438 IndexPair ptr(0, 0);
4439 std::vector<bool> independent_variable = glob.inv_marks();
4440 std::vector<Position> ans(glob.inv_index.size());
4442 for (size_t i = 0; i < glob.opstack.size(); i++) {
4443 Index nout = glob.opstack[i]->output_size();
4444 for (Index j = 0; j < nout; j++) {
4445 if (independent_variable[ptr.second + j]) {
4451 glob.opstack[i]->increment(ptr);
4458 for (size_t i = 1; i < inv_idx.size(); i++) {
4459 TMBAD_ASSERT(inv_idx[i] > inv_idx[i - 1]);
4461 std::vector<bool> marks(glob.values.size(), false);
4462 for (size_t i = 0; i < inv_idx.size(); i++)
4463 marks[glob.inv_index[inv_idx[i]]] = true;
4466 int c = std::count(marks.begin(), marks.end(), true);
4467 Rcout << "marked proportion:" << (double)c / (double)marks.size() << "\n";
4478 #include "integrate.hpp"
4481 double value(double x) { return x; }
4483 control::control(int subdivisions_, double reltol_, double abstol_)
4484 : subdivisions(subdivisions_), reltol(reltol_), abstol(abstol_) {}
4487 #include "radix.hpp"
4490 #include "tmbad_allow_comparison.hpp"
4496 bool operator<(const Scalar &x, const ad_aug &y) { return x < y.Value(); }
4501 bool operator<=(const Scalar &x, const ad_aug &y) { return x <= y.Value(); }
4506 bool operator>(const Scalar &x, const ad_aug &y) { return x > y.Value(); }
4511 bool operator>=(const Scalar &x, const ad_aug &y) { return x >= y.Value(); }
4516 bool operator==(const Scalar &x, const ad_aug &y) { return x == y.Value(); }
4521 bool operator!=(const Scalar &x, const ad_aug &y) { return x != y.Value(); }
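These are weak comparisons in the AD sense: they compare against the current numerical value (y.Value()) and return a plain bool, so nothing is recorded on the tape. A hypothetical illustration (variable names are made up):

// ad_aug y = 2.0;
// bool flag = (1.0 < y);  // evaluates 1.0 < y.Value(); no AD node is created,
//                         // so any branch controlled by 'flag' is constant
//                         // from the tape's point of view.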
4527 VSumOp::VSumOp(size_t n) : n(n) {}
4529 void VSumOp::dependencies(Args<> &args, Dependencies &dep) const {
4530 dep.add_segment(args.input(0), n);
4537 const char *VSumOp::op_name() { return "VSumOp"; }
4544 Scalar *SegmentRef::value_ptr() { return (*glob_ptr).values.data() + offset; }
4546 Scalar *SegmentRef::deriv_ptr() { return (*glob_ptr).derivs.data() + offset; }
4548 SegmentRef::SegmentRef() {}
4550 SegmentRef::SegmentRef(const Scalar *x) {
4555 SegmentRef::SegmentRef(global *g, Index o, Index s)
4556 : glob_ptr(g), offset(o), size(s) {}
4558 SegmentRef::SegmentRef(const ad_segment &x) {
4559 static const size_t K = ScalarPack<SegmentRef>::size;
4560 TMBAD_ASSERT(x.size() == K);
4562 for (size_t i = 0; i < K; i++) buf[i] = x[i].Value();
4567 bool SegmentRef::isNull() { return (glob_ptr == NULL); }
4570 Index i = pack.index();
4575 PackOp::PackOp(const Index n) : n(n) {}
4589 if (tmp.glob_ptr != NULL) {
4592 for (Index i = 0; i < n; i++) dx[i] += dy[i];
4605 Replay *pdx = args.dx_ptr(0);
4606 for (Index i = 0; i < n; i++) pdx[i] = dx[i];
4609 const char *PackOp::op_name() { return "PackOp"; }
4611 void PackOp::dependencies(Args<> &args, Dependencies &dep) const {
4612 dep.add_segment(args.input(0), n);
4615 UnpkOp::UnpkOp(const Index n) : noutput(n) {}
4618 Scalar *y = args.y_ptr(0);
4621 for (Index i = 0; i < noutput; i++) y[i] = 0;
4624 Scalar *x = srx.value_ptr();
4625 for (Index i = 0; i < noutput; i++) y[i] = x[i];
4638 Replay *pdx = args.dx_ptr(0);
4639 for (Index i = 0; i < dy_packed.size(); i++) pdx[i] = dy_packed[i];
4642 const char *UnpkOp::op_name() { return "UnpkOp"; }
4644 void UnpkOp::dependencies(Args<> &args, Dependencies &dep) const {
4645 dep.add_segment(args.input(0), K);
4659 Scalar *unpack(const std::vector<Scalar> &x, Index j) {
4660 Index K = ScalarPack<SegmentRef>::size;
4662 return sr.value_ptr();
4665 std::vector<ad_aug> concat(const std::vector<ad_segment> &x) {
4666 std::vector<ad_aug> ans;
4667 for (size_t i = 0; i < x.size(); i++) {
4669 for (size_t j = 0; j < xi.size(); j++) {
4670 ans.push_back(xi[j]);