1 #ifndef HAVE_GLOBAL_HPP 2 #define HAVE_GLOBAL_HPP 22 typedef TMBAD_HASH_TYPE hash_t;
// Core type aliases. The underlying types come from TMBAD_* configuration
// macros that are defined outside this view.
// Index: integer type used for positions in the value / input arrays.
23 typedef TMBAD_INDEX_TYPE Index;
// Scalar: numeric type stored in the tape's value arrays.
24 typedef TMBAD_SCALAR_TYPE Scalar;
// IndexPair: pair of indices; used e.g. as the `ptr` cursor whose `first`
// member walks the inputs and whose `second` member walks the outputs.
25 typedef std::pair<Index, Index> IndexPair;
26 typedef TMBAD_INDEX_VECTOR IndexVector;
34 std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
36 size_t last = v.size() - 1;
37 for (
size_t i = 0; i < v.size(); ++i) {
39 if (i != last) out <<
", ";
48 struct ep : std::pair<T, bool> {
49 bool left()
const {
return !this->second; }
50 ep(T x,
bool type) : std::pair<T, bool>(x, type) {}
51 operator T() {
return this->first; }
54 typedef typename std::set<ep>::iterator iterator;
61 iterator it1 = x.upper_bound(x1);
62 iterator it2 = x.lower_bound(x2);
64 bool insert_x1 = (it1 == x.end()) || it1->left();
65 bool insert_x2 = (it2 == x.end()) || it2->left();
67 bool change = (it1 != it2) || insert_x1;
73 if (insert_x1) x.insert(x1);
74 if (insert_x2) x.insert(x2);
80 for (iterator it = x.begin(); it != x.end();) {
89 struct print_interval {
90 void operator()(T a, T b) { Rcout <<
"[ " << a <<
" , " << b <<
" ] "; }
99 struct Dependencies : std::vector<Index> {
100 typedef std::vector<Index> Base;
101 std::vector<std::pair<Index, Index> > I;
104 void add_interval(Index a, Index b);
105 void add_segment(Index start, Index size);
107 void monotone_transform_inplace(
const std::vector<Index> &x);
111 for (
size_t i = 0; i < this->size(); i++) f((*
this)[i]);
112 for (
size_t i = 0; i < I.size(); i++) {
113 for (Index j = I[i].first; j <= I[i].second; j++) {
122 for (
size_t i = 0; i < this->size(); i++) f((*
this)[i]);
123 for (
size_t i = 0; i < I.size(); i++) {
124 if (visited.
insert(I[i].first, I[i].second)) {
125 for (Index j = I[i].first; j <= I[i].second; j++) {
133 bool any(
const std::vector<bool> &x)
const;
/// Tags selecting an Accessor<Args, What> specialization.
/// The *_read tags pick a const call operator taking `const Args &` and
/// returning a value; the *_write tags pick a non-const call operator that
/// returns a mutable reference into `Args &`.
enum ArrayAccess {
  x_read,
  y_read,
  y_write,
  dx_read,
  dx_write,
  dy_read
};
139 template <
class Args, ArrayAccess What>
141 template <
class Args>
142 struct Accessor<Args, x_read> {
143 typename Args::value_type operator()(
const Args &args, Index j)
const {
147 template <
class Args>
148 struct Accessor<Args, y_read> {
149 typename Args::value_type operator()(
const Args &args, Index j)
const {
153 template <
class Args>
154 struct Accessor<Args, y_write> {
155 typename Args::value_type &operator()(Args &args, Index j) {
159 template <
class Args>
160 struct Accessor<Args, dx_read> {
161 typename Args::value_type operator()(
const Args &args, Index j)
const {
165 template <
class Args>
166 struct Accessor<Args, dx_write> {
167 typename Args::value_type &operator()(Args &args, Index j) {
171 template <
class Args>
172 struct Accessor<Args, dy_read> {
173 typename Args::value_type operator()(
const Args &args, Index j)
const {
185 const std::vector<T> &x;
186 const std::vector<Index> &i;
189 T operator[](
size_t j)
const {
return x[i[j]]; }
190 size_t size()
const {
return i.size(); }
191 operator std::vector<T>()
const {
192 std::vector<T> ans(i.size());
193 for (
size_t j = 0; j < ans.size(); j++) ans[j] = (*
this)[j];
205 template <
class Args, ArrayAccess What>
207 typedef typename Args::value_type Type;
208 Accessor<Args, What> element_access;
212 : args(args), from(from), n(n) {}
213 template <
class Other>
216 for (
size_t i = 0; i < n; i++) {
217 ans[i] = element_access(args, from + i);
221 Type operator[](Index i)
const {
return element_access(args, from + i); }
222 size_t size()
const {
return n; }
223 template <
class Other>
225 for (
size_t i = 0; i < n; i++) {
226 element_access(args, from + i) = other[i];
230 template <
class Other>
232 for (
size_t i = 0; i < n; i++) {
233 element_access(args, from + i) += other[i];
237 template <
class Other>
239 for (
size_t i = 0; i < n; i++) {
240 element_access(args, from + i) -= other[i];
255 template <
class dummy =
void>
265 Index
input(Index j)
const {
return inputs[ptr.first + j]; }
267 Index
output(Index j)
const {
return ptr.second + j; }
268 Args(
const IndexVector &inputs) : inputs(inputs.data()) {
278 template <
class Type>
280 typedef std::vector<Type> TypeVector;
281 typedef Type value_type;
285 Type x(Index j)
const {
return values[input(j)]; }
287 Type &y(Index j) {
return values[output(j)]; }
289 Type *x_ptr(Index j) {
return &values[input(j)]; }
291 Type *y_ptr(Index j) {
return &values[output(j)]; }
300 ForwardArgs(
const IndexVector &inputs, TypeVector &values,
302 :
Args<>(inputs), values(values.data()), glob_ptr(glob_ptr) {}
310 template <
class Type>
312 typedef std::vector<Type> TypeVector;
313 typedef Type value_type;
318 Type x(Index j)
const {
return values[input(j)]; }
320 Type y(Index j)
const {
return values[output(j)]; }
323 Type &dx(Index j) {
return derivs[input(j)]; }
326 Type dy(Index j)
const {
return derivs[output(j)]; }
328 Type *x_ptr(Index j) {
return &values[input(j)]; }
330 Type *y_ptr(Index j) {
return &values[output(j)]; }
332 Type *dx_ptr(Index j) {
return &derivs[input(j)]; }
334 Type *dy_ptr(Index j) {
return &derivs[output(j)]; }
351 ReverseArgs(
const IndexVector &inputs, TypeVector &values, TypeVector &derivs,
354 values(values.data()),
355 derivs(derivs.data()),
357 ptr.first = (Index)inputs.size();
358 ptr.second = (Index)values.size();
364 typedef std::vector<bool> BoolVector;
367 bool x(Index j) {
return values[input(j)]; }
368 BoolVector::reference y(Index j) {
return values[output(j)]; }
369 ForwardArgs(
const IndexVector &inputs, BoolVector &values,
371 :
Args<>(inputs), values(values), marked_intervals(marked_intervals) {}
373 template <
class Operator>
374 bool any_marked_input(
const Operator &op) {
375 if (Operator::implicit_dependencies) {
377 op.dependencies(*
this, dep);
378 return dep.any(values);
380 Index ninput = op.input_size();
381 for (Index j = 0; j < ninput; j++)
382 if (x(j))
return true;
387 template <
class Operator>
388 void mark_all_output(
const Operator &op) {
389 if (Operator::updating && op.output_size() == 0) {
391 op.dependencies_updating(*
this, dep);
393 for (
size_t i = 0; i < dep.size(); i++) values[dep[i]] =
true;
395 for (
size_t i = 0; i < dep.I.size(); i++) {
396 Index a = dep.I[i].first;
397 Index b = dep.I[i].second;
400 for (Index j = a; j <= b; j++) {
406 Index noutput = op.output_size();
407 for (Index j = 0; j < noutput; j++) y(j) =
true;
411 template <
class Operator>
412 bool mark_dense(
const Operator &op) {
413 if (any_marked_input(op)) {
423 typedef std::vector<bool> BoolVector;
426 BoolVector::reference x(Index j) {
return values[input(j)]; }
427 bool y(Index j) {
return values[output(j)]; }
428 ReverseArgs(IndexVector &inputs, BoolVector &values,
430 :
Args<>(inputs), values(values), marked_intervals(marked_intervals) {
431 ptr.first = (Index)inputs.size();
432 ptr.second = (Index)values.size();
435 template <
class Operator>
436 bool any_marked_output(
const Operator &op) {
437 if (Operator::elimination_protected)
return true;
438 if (Operator::updating && op.output_size() == 0) {
440 op.dependencies_updating(*
this, dep);
441 return dep.any(values);
443 Index noutput = op.output_size();
444 for (Index j = 0; j < noutput; j++)
445 if (y(j))
return true;
450 template <
class Operator>
451 void mark_all_input(
const Operator &op) {
452 if (Operator::implicit_dependencies) {
454 op.dependencies(*
this, dep);
456 for (
size_t i = 0; i < dep.size(); i++) values[dep[i]] =
true;
458 for (
size_t i = 0; i < dep.I.size(); i++) {
459 Index a = dep.I[i].first;
460 Index b = dep.I[i].second;
463 for (Index j = a; j <= b; j++) {
469 Index ninput = op.input_size();
470 for (Index j = 0; j < ninput; j++) x(j) =
true;
474 template <
class Operator>
475 bool mark_dense(
const Operator &op) {
476 if (any_marked_output(op)) {
// Text conversion helpers (definitions not in view). They are used to
// splice numbers into generated source strings, e.g. "v[" + tostr(input(j)) + "]".
// Overload for Index values.
484 std::string tostr(
const Index &x);
// Overload for Scalar values.
486 std::string tostr(
const Scalar &x);
488 struct Writer : std::string {
489 static std::ostream *cout;
490 Writer(std::string str);
495 std::string vinit(
const V &x) {
497 for (
size_t i = 0; i < x.size(); i++)
498 y = y + (i == 0 ?
"" :
",") + tostr(x[i]);
503 std::string p(std::string x);
504 Writer operator+(
const Writer &other);
505 Writer operator-(
const Writer &other);
508 Writer operator/(
const Writer &other);
510 Writer operator*(
const Scalar &other);
511 Writer operator+(
const Scalar &other);
513 void operator=(
const Writer &other);
514 void operator+=(
const Writer &other);
515 void operator-=(
const Writer &other);
516 void operator*=(
const Writer &other);
517 void operator/=(
const Writer &other);
520 friend Writer &operator<<(Writer &w,
const T &v) {
525 friend Writer &operator<<(Writer &w, const std::valarray<T> &x) {
533 typedef std::vector<Scalar> ScalarVector;
539 void set_indirect() {
544 Writer xd(Index j) {
return "v[" + tostr(input(j)) +
"]"; }
545 Writer yd(Index j) {
return "v[" + tostr(output(j)) +
"]"; }
546 Writer xi(Index j) {
return "v[i[" + tostr(Index(ptr.first + j)) +
"]]"; }
547 Writer yi(Index j) {
return "v[o[" + tostr(Index(ptr.second + j)) +
"]]"; }
548 Writer x(Index j) {
return (indirect ? xi(j) : xd(j)); }
549 Writer y(Index j) {
return (indirect ? yi(j) : yd(j)); }
550 Writer y_const(Index j) {
551 TMBAD_ASSERT2(!indirect,
"Attempt to write constants within loop?");
552 return tostr(Base::y(j));
554 ForwardArgs(IndexVector &inputs, ScalarVector &values)
556 const_literals =
false;
563 typedef std::vector<Scalar> ScalarVector;
568 void set_indirect() {
573 Writer dxd(Index j) {
return "d[" + tostr(input(j)) +
"]"; }
574 Writer dyd(Index j) {
return "d[" + tostr(output(j)) +
"]"; }
575 Writer xd(Index j) {
return "v[" + tostr(input(j)) +
"]"; }
576 Writer yd(Index j) {
return "v[" + tostr(output(j)) +
"]"; }
577 Writer dxi(Index j) {
return "d[i[" + tostr(Index(ptr.first + j)) +
"]]"; }
578 Writer dyi(Index j) {
return "d[o[" + tostr(Index(ptr.second + j)) +
"]]"; }
579 Writer xi(Index j) {
return "v[i[" + tostr(Index(ptr.first + j)) +
"]]"; }
580 Writer yi(Index j) {
return "v[o[" + tostr(Index(ptr.second + j)) +
"]]"; }
581 Writer x(Index j) {
return (indirect ? xi(j) : xd(j)); }
582 Writer y(Index j) {
return (indirect ? yi(j) : yd(j)); }
583 Writer dx(Index j) {
return (indirect ? dxi(j) : dxd(j)); }
584 Writer dy(Index j) {
return (indirect ? dyi(j) : dyd(j)); }
587 const_literals =
false;
589 ptr.first = (Index)inputs.size();
590 ptr.second = (Index)values.size();
595 Position(Index node, Index first, Index second);
599 bool operator<(
const Position &other)
const;
605 std::sort(x.begin(), x.end());
611 std::sort(x.begin(), x.end());
612 typename std::vector<T>::iterator last = std::unique(x.begin(), x.end());
613 x.erase(last, x.end());
618 std::vector<Index> j;
619 std::vector<Index> p;
621 size_t num_neighbors(Index node);
622 Index *neighbors(Index node);
634 std::vector<Index> rowcounts();
636 std::vector<Index> colcounts();
646 void bfs(
const std::vector<Index> &start, std::vector<bool> &visited,
647 std::vector<Index> &result);
660 void search(std::vector<Index> &start,
bool sort_input =
true,
661 bool sort_output =
true);
669 void search(std::vector<Index> &start, std::vector<bool> &visited,
670 bool sort_input =
true,
bool sort_output =
true);
676 std::vector<Index> boundary(
const std::vector<Index> &subgraph);
681 graph(
size_t num_nodes,
const std::vector<IndexPair> &edges);
/// Primary template for operator construction. It is intentionally empty:
/// the call operators that allocate a CompleteOperator are supplied by the
/// partial specializations for `dynamic == false` and `dynamic == true`.
template <class CompleteOperator, bool dynamic>
struct constructOperator {
};
687 template <
class CompleteOperator>
688 struct constructOperator<CompleteOperator, false> {
689 CompleteOperator *operator()() {
690 static CompleteOperator *pOp =
new CompleteOperator();
694 template <
class CompleteOperator>
695 struct constructOperator<CompleteOperator, true> {
696 CompleteOperator *operator()() {
697 CompleteOperator *pOp =
new CompleteOperator();
702 CompleteOperator *operator()(
const T1 &x1) {
703 CompleteOperator *pOp =
new CompleteOperator(x1);
707 template <
class T1,
class T2>
708 CompleteOperator *operator()(
const T1 &x1,
const T2 &x2) {
709 CompleteOperator *pOp =
new CompleteOperator(x1, x2);
713 template <
class T1,
class T2,
class T3>
714 CompleteOperator *operator()(
const T1 &x1,
const T2 &x2,
const T3 &x3) {
715 CompleteOperator *pOp =
new CompleteOperator(x1, x2, x3);
719 template <
class T1,
class T2,
class T3,
class T4>
720 CompleteOperator *operator()(
const T1 &x1,
const T2 &x2,
const T3 &x3,
722 CompleteOperator *pOp =
new CompleteOperator(x1, x2, x3, x4);
761 IntRep get_flags(T op) {
764 (op.dynamic * (1 << dynamic)) |
765 (op.smart_pointer * (1 << smart_pointer)) |
766 (op.is_linear * (1 << is_linear)) |
767 (op.is_constant * (1 << is_constant)) |
768 (op.independent_variable * (1 << independent_variable)) |
769 (op.dependent_variable * (1 << dependent_variable)) |
770 (op.allow_remap * (1 << allow_remap)) |
771 (op.elimination_protected * (1 << elimination_protected)) |
772 (op.updating * (1 << updating));
778 op_info(T op) : code(get_flags(op)) {}
800 typedef TMBAD_REPLAY_TYPE Replay;
814 virtual void increment(IndexPair &ptr) = 0;
817 virtual void decrement(IndexPair &ptr) = 0;
827 virtual Index input_size() = 0;
829 virtual Index output_size() = 0;
859 virtual void dependencies(
Args<> &args, Dependencies &dep) = 0;
863 virtual void dependencies_updating(
Args<> &args, Dependencies &dep) = 0;
// Human-readable name of the operator. The base implementation returns the
// placeholder "NoName"; concrete operators may supply their own name
// (e.g. "Fused", "Rep", "CRep").
881 virtual const char *
op_name() {
return "NoName"; }
893 virtual void deallocate() = 0;
897 virtual void *operator_data() = 0;
902 virtual void *identifier() = 0;
907 virtual void *incomplete() = 0;
918 typedef std::vector<OperatorPure *> Base;
953 mutable std::vector<IndexPair> subgraph_ptr;
954 std::vector<Index> subgraph_seq;
956 void (*forward_compiled)(Scalar *);
958 void (*reverse_compiled)(Scalar *, Scalar *);
980 void shrink_to_fit(
double tol = .9);
985 void clear_deriv(Position start = Position(0, 0, 0));
988 Scalar &value_inv(Index i);
990 Scalar &deriv_inv(Index i);
992 Scalar &value_dep(Index i);
994 Scalar &deriv_dep(Index i);
1002 CONSTEXPR
bool operator[](
size_t i)
const;
1009 template <
class ForwardArgs,
class NodeFilter>
1011 const NodeFilter &node_filter)
const {
1012 for (
size_t i = begin; i < opstack.size(); i++) {
1014 opstack[i]->forward_incr(args);
1016 opstack[i]->increment(args.
ptr);
1020 template <
class ForwardArgs>
1028 template <
class ReverseArgs,
class NodeFilter>
1030 const NodeFilter &node_filter)
const {
1031 for (
size_t i = opstack.size(); i > begin;) {
1034 opstack[i]->reverse_decr(args);
1036 opstack[i]->decrement(args.
ptr);
1040 template <
class ReverseArgs>
1045 template <
class ForwardArgs>
1047 subgraph_cache_ptr();
1048 for (
size_t j = 0; j < subgraph_seq.size(); j++) {
1049 Index i = subgraph_seq[j];
1050 args.
ptr = subgraph_ptr[i];
1051 opstack[i]->forward(args);
1055 template <
class ReverseArgs>
1057 subgraph_cache_ptr();
1058 for (
size_t j = subgraph_seq.size(); j > 0;) {
1060 Index i = subgraph_seq[j];
1061 args.
ptr = subgraph_ptr[i];
1062 opstack[i]->reverse(args);
1075 template <
class Vector>
1077 typename Vector::value_type
value =
1078 typename Vector::value_type(0))
const {
1079 if (array.size() != values.size()) {
1080 array.resize(values.size());
1081 std::fill(array.begin(), array.end(),
value);
1084 subgraph_cache_ptr();
1085 for (
size_t j = 0; j < subgraph_seq.size(); j++) {
1086 Index i = subgraph_seq[j];
1087 size_t noutput = opstack[i]->output_size();
1088 for (
size_t k = 0; k < noutput; k++)
1089 array[subgraph_ptr[i].second + k] =
value;
1097 void forward(Position start = Position(0, 0, 0));
1105 void reverse(Position start = Position(0, 0, 0));
1112 void forward(std::vector<bool> &marks);
1114 void reverse(std::vector<bool> &marks);
1119 void forward_sub(std::vector<bool> &marks,
1120 const std::vector<bool> &node_filter = std::vector<bool>());
1125 void reverse_sub(std::vector<bool> &marks,
1126 const std::vector<bool> &node_filter = std::vector<bool>());
1135 void forward_dense(std::vector<bool> &marks);
1143 std::vector<Replay> values;
1146 std::vector<Replay> derivs;
1154 Replay &value_inv(Index i);
1156 Replay &deriv_inv(Index i);
1158 Replay &value_dep(Index i);
1160 Replay &deriv_dep(Index i);
1189 void forward(
bool inv_tags =
true,
bool dep_tags =
true,
1190 Position start = Position(0, 0, 0),
1191 const std::vector<bool> &node_filter = std::vector<bool>());
1199 void reverse(
bool dep_tags =
true,
bool inv_tags =
false,
1200 Position start = Position(0, 0, 0),
1201 const std::vector<bool> &node_filter = std::vector<bool>());
1207 void clear_deriv_sub();
1214 void forward_replay(
bool inv_tags =
true,
bool dep_tags =
true);
1221 void subgraph_cache_ptr()
const;
1229 void set_subgraph(
const std::vector<bool> &marks,
bool append =
false);
1231 void mark_subgraph(std::vector<bool> &marks);
1233 void unmark_subgraph(std::vector<bool> &marks);
1235 void subgraph_trivial();
1241 void clear_deriv_sub();
1279 void extract_sub_inplace(std::vector<bool> marks);
1293 std::vector<Index> var2op();
1299 std::vector<bool> var2op(
const std::vector<bool> &values);
1301 std::vector<Index> op2var(
const std::vector<Index> &seq);
1303 std::vector<bool> op2var(
const std::vector<bool> &seq_mark);
1312 std::vector<Index> op2idx(
const std::vector<Index> &var_subset,
1313 Index NA = (Index)-1);
1315 std::vector<bool> mark_space(
size_t n,
const std::vector<Index> ind);
1317 std::vector<bool> inv_marks();
1319 std::vector<bool> dep_marks();
1321 std::vector<bool> subgraph_marks();
1323 struct append_edges {
1325 const std::vector<bool> &keep_var;
1326 std::vector<Index> &var2op;
1327 std::vector<IndexPair> &edges;
1329 std::vector<bool> op_marks;
1331 append_edges(
size_t &i,
size_t num_nodes,
const std::vector<bool> &keep_var,
1332 std::vector<Index> &var2op, std::vector<IndexPair> &edges);
1333 void operator()(Index dep_j);
1335 void start_iteration();
1337 void end_iteration();
1347 graph build_graph(
bool transpose,
const std::vector<bool> &keep_var);
1351 graph forward_graph(std::vector<bool> keep_var = std::vector<bool>(0));
1355 graph reverse_graph(std::vector<bool> keep_var = std::vector<bool>(0));
1361 bool identical(
const global &other)
const;
1366 static const size_t n =
1367 (
sizeof(T) /
sizeof(hash_t)) + (
sizeof(T) %
sizeof(hash_t) != 0);
1369 std::fill(buffer, buffer + n, 0);
1370 for (
size_t i = 0; i <
sizeof(x); i++)
1371 ((
char *)buffer)[i] = ((
char *)&x)[i];
1374 for (
size_t i = 0; i < n; i++) h = (A * h) ^ (B * buffer[i]);
1385 hash_t hash()
const;
1459 std::vector<hash_t> hash_sweep(
hash_config cfg)
const;
1461 std::vector<hash_t> hash_sweep(
bool weak =
true)
const;
1480 std::string prefix, mark;
1490 template <
int ninput_,
int noutput_ = 1>
1493 static const bool dynamic =
false;
1495 static const int ninput = ninput_;
1497 static const int noutput = noutput_;
1499 static const int independent_variable =
false;
1501 static const int dependent_variable =
false;
1503 static const bool have_input_size_output_size =
false;
1505 static const bool have_increment_decrement =
false;
1507 static const bool have_forward_reverse =
true;
1509 static const bool have_forward_incr_reverse_decr =
false;
1511 static const bool have_forward_mark_reverse_mark =
false;
1513 static const bool have_dependencies =
false;
1519 static const bool allow_remap =
true;
1530 static const bool implicit_dependencies =
false;
1532 static const bool add_static_identifier =
false;
1535 static const bool add_forward_replay_copy =
false;
1538 static const bool have_eval =
false;
1540 static const int max_fuse_depth = 2;
1542 static const bool is_linear =
false;
1544 static const bool is_constant =
false;
1546 static const bool smart_pointer =
false;
1548 static const bool elimination_protected =
false;
1574 static const bool updating =
false;
1589 template <
int ninput,
int noutput>
1592 static const bool dynamic =
true;
1594 static const int max_fuse_depth = 0;
1598 template <
int ninput>
1601 static const bool dynamic =
true;
1603 static const int max_fuse_depth = 0;
1606 template <
int noutput = 1>
1607 struct DynamicInputOperator :
Operator<-1, noutput> {
1609 static const bool dynamic =
true;
1611 static const int max_fuse_depth = 0;
1614 struct DynamicInputOutputOperator :
Operator<-1, -1> {
1616 static const bool dynamic =
true;
1618 static const int max_fuse_depth = 0;
1619 Index ninput_, noutput_;
1620 DynamicInputOutputOperator(Index ninput, Index noutput);
1621 Index input_size()
const;
1622 Index output_size()
const;
1623 static const bool have_input_size_output_size =
true;
1625 struct UniqueDynamicOperator :
Operator<-1, -1> {
1627 static const bool dynamic =
true;
1629 static const int max_fuse_depth = 0;
1631 static const bool smart_pointer =
false;
1634 static const bool have_input_size_output_size =
true;
1636 struct SharedDynamicOperator : UniqueDynamicOperator {
1638 static const bool smart_pointer =
true;
1643 template <
class OperatorBase>
1646 Index input_size()
const {
return this->ninput; }
1647 Index output_size()
const {
return this->noutput; }
1648 static const bool have_input_size_output_size =
true;
1653 template <
class OperatorBase>
1656 void increment(IndexPair &ptr) {
1657 ptr.first += this->input_size();
1658 ptr.second += this->output_size();
1660 void decrement(IndexPair &ptr) {
1661 ptr.first -= this->input_size();
1662 ptr.second -= this->output_size();
1664 static const bool have_increment_decrement =
true;
1670 template <
class OperatorBase>
1674 template <
class Type>
1677 OperatorBase::forward_incr(args_cpy);
1679 template <
class Type>
1682 OperatorBase::increment(args_cpy.
ptr);
1683 OperatorBase::reverse_decr(args_cpy);
1685 static const bool have_forward_reverse =
true;
1691 template <
class OperatorBase>
1695 template <
class Type>
1697 OperatorBase::forward(args);
1698 OperatorBase::increment(args.
ptr);
1701 template <
class Type>
1703 OperatorBase::decrement(args.
ptr);
1704 OperatorBase::reverse(args);
1706 static const bool have_forward_incr_reverse_decr =
true;
1711 template <
class OperatorBase>
1715 template <
class Type>
1717 OperatorBase::forward(args);
1719 template <
class Type>
1721 OperatorBase::reverse(args);
1726 static const bool have_forward_mark_reverse_mark =
true;
1731 template <
class OperatorBase>
1734 void dependencies(
Args<> &args, Dependencies &dep)
const {
1735 Index ninput_ = this->input_size();
1736 for (Index j = 0; j < ninput_; j++) dep.push_back(args.
input(j));
1738 static const bool have_dependencies =
true;
1743 template <
class OperatorBase,
int ninput>
1746 template <
class OperatorBase>
1749 template <
class Type>
1751 args.
y(0) = this->eval(args.
x(0));
1755 template <
class OperatorBase>
1758 template <
class Type>
1760 args.
y(0) = this->eval(args.
x(0), args.
x(1));
1765 template <
bool flag,
class dummy>
1769 size_t operator()()
const {
return 0; }
1771 template <
class dummy>
1775 void increment() { counter++; }
1776 void decrement() { counter--; }
1777 size_t operator()()
const {
return counter; }
1781 template <
bool flag,
class Yes,
class No>
1783 template <
class Yes,
class No>
1784 struct if_else<true, Yes, No> {
1787 template <
class Yes,
class No>
1788 struct if_else<false, Yes, No> {
1793 template <
class OperatorBase>
1795 static const bool test1 = !OperatorBase::have_eval;
1798 test1, OperatorBase,
1801 static const bool test2 = Result1::have_input_size_output_size;
1807 static const bool test3 = !Result2::have_dependencies;
1812 static const bool test4 = Result3::have_increment_decrement;
1818 static const bool test5 = Result4::have_forward_mark_reverse_mark;
1823 static const bool test6 = Result5::have_forward_reverse &&
1824 !Result5::have_forward_incr_reverse_decr;
1829 static const bool test7 = Result6::have_forward_incr_reverse_decr &&
1830 !Result6::have_forward_reverse;
1839 template <
class Operator1,
class Operator2>
1841 Operator1::noutput + Operator2::noutput> {
1845 static const int independent_variable =
1846 Operator1::independent_variable && Operator2::independent_variable;
1848 static const int dependent_variable =
1849 Operator1::dependent_variable && Operator2::dependent_variable;
1851 static const int max_fuse_depth =
1852 (Operator1::max_fuse_depth < Operator2::max_fuse_depth
1853 ? Operator1::max_fuse_depth - 1
1854 : Operator2::max_fuse_depth - 1);
1856 static const bool is_linear = Operator1::is_linear && Operator2::is_linear;
1857 template <
class Type>
1859 Op1.forward_incr(args);
1860 Op2.forward_incr(args);
1862 template <
class Type>
1864 Op2.reverse_decr(args);
1865 Op1.reverse_decr(args);
1868 static const bool have_forward_incr_reverse_decr =
true;
1870 static const bool have_forward_reverse =
false;
1871 const char *op_name() {
return "Fused"; }
1881 template <
class Operator1>
1885 static const int independent_variable = Operator1::independent_variable;
1887 static const int dependent_variable = Operator1::dependent_variable;
1889 static const bool is_linear = Operator1::is_linear;
1891 Rep(Index n) : n(n) {}
1892 Index input_size()
const {
return Operator1::ninput * n; }
1893 Index output_size()
const {
return Operator1::noutput * n; }
1895 static const bool have_input_size_output_size =
true;
1896 template <
class Type>
1898 for (
size_t i = 0; i < (size_t)n; i++) Op.forward_incr(args);
1900 template <
class Type>
1902 for (
size_t i = 0; i < (size_t)n; i++) Op.reverse_decr(args);
1905 static const bool have_forward_incr_reverse_decr =
true;
1907 static const bool have_forward_reverse =
false;
1914 TMBAD_ASSERT(
false);
1916 size_t k = Op.input_size();
1917 size_t start = inputs.size() - k * n;
1918 std::valarray<Index> increment(k);
1920 for (
size_t i = 0; i < (size_t)n - 1; i++) {
1921 std::valarray<Index> v1(&inputs[start + i * k], k);
1922 std::valarray<Index> v2(&inputs[start + (i + 1) * k], k);
1924 increment = v2 - v1;
1926 bool ok = (increment == (v2 - v1)).min();
1927 if (!ok)
return NULL;
1932 size_t reduction = (n - 1) * k;
1933 inputs.resize(inputs.size() - reduction);
1944 const char *op_name() {
return "Rep"; }
1956 template <
class Operator1>
1959 static const int independent_variable = Operator1::independent_variable;
1961 static const int dependent_variable = Operator1::dependent_variable;
1963 static const bool is_linear = Operator1::is_linear;
1967 std::valarray<Index> increment_pattern;
1968 RepCompress(Index n, std::valarray<Index> v) : n(n), increment_pattern(v) {}
1969 Index input_size()
const {
return Operator1::ninput; }
1970 Index output_size()
const {
return Operator1::noutput * n; }
1972 static const bool have_input_size_output_size =
true;
1974 template <
class Type>
1976 std::valarray<Index> inputs(input_size());
1977 for (
size_t i = 0; i < inputs.size(); i++) inputs[i] = args.
input(i);
1979 args_cpy.
inputs = &inputs[0];
1980 args_cpy.
ptr.first = 0;
1981 for (
size_t i = 0; i < (size_t)n; i++) {
1982 Op.forward(args_cpy);
1983 inputs += this->increment_pattern;
1984 args_cpy.
ptr.second += Op.output_size();
1988 template <
class Type>
1990 std::valarray<Index> inputs(input_size());
1991 for (
size_t i = 0; i < inputs.size(); i++) inputs[i] = args.
input(i);
1992 inputs += n * this->increment_pattern;
1994 args_cpy.
inputs = &inputs[0];
1995 args_cpy.
ptr.first = 0;
1996 args_cpy.
ptr.second += n * Op.output_size();
1997 for (
size_t i = 0; i < (size_t)n; i++) {
1998 inputs -= this->increment_pattern;
1999 args_cpy.
ptr.second -= Op.output_size();
2000 Op.reverse(args_cpy);
2005 std::valarray<Index> inputs(input_size());
2006 for (
size_t i = 0; i < inputs.size(); i++) inputs[i] = args.
input(i);
2007 for (
size_t i = 0; i < (size_t)n; i++) {
2008 dep.insert(dep.end(), &inputs[0], &inputs[0] + inputs.size());
2009 inputs += this->increment_pattern;
2012 static const bool have_dependencies =
true;
2014 std::valarray<Index> inputs(Op.input_size());
2015 for (
size_t i = 0; i < (size_t)Op.input_size(); i++)
2016 inputs[i] = args.
input(i);
2017 std::valarray<Index> outputs(Op.output_size());
2018 for (
size_t i = 0; i < (size_t)Op.output_size(); i++)
2019 outputs[i] = args.
output(i);
2021 int ninp = Op.input_size();
2022 int nout = Op.output_size();
2024 w <<
"for (int count = 0, " 2025 <<
"i[" << ninp <<
"]=" << inputs <<
", " 2026 <<
"di[" << ninp <<
"]=" << increment_pattern <<
", " 2027 <<
"o[" << nout <<
"]=" << outputs <<
"; " 2028 <<
"count < " << n <<
"; count++) {\n";
2032 args_cpy.set_indirect();
2033 Op.forward(args_cpy);
2037 w <<
"for (int k=0; k<" << ninp <<
"; k++) i[k] += di[k];\n";
2039 w <<
"for (int k=0; k<" << nout <<
"; k++) o[k] += " << nout <<
";\n";
2045 std::valarray<Index> inputs(Op.input_size());
2046 for (
size_t i = 0; i < (size_t)Op.input_size(); i++)
2047 inputs[i] = args.
input(i);
2048 inputs += n * increment_pattern;
2049 std::valarray<Index> outputs(Op.output_size());
2050 for (
size_t i = 0; i < (size_t)Op.output_size(); i++)
2051 outputs[i] = args.
output(i);
2052 outputs += n * Op.output_size();
2054 int ninp = Op.input_size();
2055 int nout = Op.output_size();
2057 w <<
"for (int count = 0, " 2058 <<
"i[" << ninp <<
"]=" << inputs <<
", " 2059 <<
"di[" << ninp <<
"]=" << increment_pattern <<
", " 2060 <<
"o[" << nout <<
"]=" << outputs <<
"; " 2061 <<
"count < " << n <<
"; count++) {\n";
2064 w <<
"for (int k=0; k<" << ninp <<
"; k++) i[k] -= di[k];\n";
2066 w <<
"for (int k=0; k<" << nout <<
"; k++) o[k] -= " << nout <<
";\n";
2070 args_cpy.set_indirect();
2071 Op.reverse(args_cpy);
2078 static const bool have_forward_incr_reverse_decr =
false;
2080 static const bool have_forward_reverse =
true;
2082 static const bool have_forward_mark_reverse_mark =
true;
2083 const char *op_name() {
return "CRep"; }
2085 struct operator_data_t {
2088 std::valarray<Index> ip;
2090 : Op(
get_glob()->getOperator<Operator1>()),
2092 ip(x.increment_pattern) {}
2094 bool operator==(
const operator_data_t &other) {
2095 return (Op == other.Op) && (ip.size() == other.ip.size()) &&
2096 ((ip - other.ip).min() == 0);
2099 void *operator_data() {
return new operator_data_t(*
this); }
2101 if (this->op_name() == other->
op_name()) {
2102 operator_data_t *p1 =
2103 static_cast<operator_data_t *
>(
self->operator_data());
2104 operator_data_t *p2 =
2106 bool match = (*p1 == *p2);
2107 int other_n = p2->n;
2112 size_t reduction = increment_pattern.size();
2113 inputs.resize(inputs.size() - reduction);
2128 template <
class OperatorBase>
2130 typename CPL<OperatorBase>::type Op;
2139 if (Op.add_forward_replay_copy)
2140 forward_replay_copy(args);
2146 if (Op.add_forward_replay_copy) {
2147 forward_replay_copy(args);
2148 increment(args.
ptr);
2150 Op.forward_incr(args);
2159 args.mark_dense(Op);
2160 Op.increment(args.
ptr);
2171 std::vector<ad_plain>
operator()(
const std::vector<ad_plain> &x) {
2172 TMBAD_ASSERT2(OperatorBase::dynamic,
2173 "Stack to heap copy only allowed for dynamic operators");
2175 TMBAD_ASSERT2(pOp->ref_count() == 0,
"Operator already on the heap");
2176 pOp->ref_count.increment();
2180 TMBAD_ASSERT2(OperatorBase::dynamic,
2181 "Stack to heap copy only allowed for dynamic operators");
2183 TMBAD_ASSERT2(pOp->ref_count() == 0,
"Operator already on the heap");
2184 pOp->ref_count.increment();
2188 TMBAD_ASSERT2(OperatorBase::dynamic,
2189 "Stack to heap copy only allowed for dynamic operators");
2191 TMBAD_ASSERT2(pOp->ref_count() == 0,
"Operator already on the heap");
2192 pOp->ref_count.increment();
2196 std::vector<T> operator()(
const std::vector<T> &x) {
2197 std::vector<ad_plain> x_(x.begin(), x.end());
2198 std::vector<ad_plain> y_ = (*this)(x_);
2199 std::vector<T> y(y_.begin(), y_.end());
2203 std::vector<ad_plain> x(Op.input_size());
2204 for (
size_t i = 0; i < x.size(); i++) x[i] = args.
x(i);
2205 std::vector<ad_plain> y =
2207 for (
size_t i = 0; i < y.size(); i++) args.
y(i) = y[i];
2210 Op.dependencies(args, dep);
2213 Op.dependencies_updating(args, dep);
2222 template <
class Operator_,
int depth>
2226 return get_glob()->template getOperator<type>(2);
2229 template <
class Operator_>
2230 struct SelfFuse<Operator_, 0> {
2234 return SelfFuse<OperatorBase, OperatorBase::max_fuse_depth>()();
2237 return Op.other_fuse(
this, other);
2241 if (Op.smart_pointer) {
2242 ref_count.increment();
2244 }
else if (Op.dynamic)
2250 if (!Op.dynamic)
return;
2251 if (Op.smart_pointer) {
2252 if (ref_count() > 1) {
2253 ref_count.decrement();
2264 if (Op.add_static_identifier) {
2265 static void *
id =
new char();
2268 return (
void *)
this;
2274 template <
class OperatorBase>
2276 return constructOperator<Complete<OperatorBase>, OperatorBase::dynamic>()();
2278 template <
class OperatorBase,
class T1>
2280 return constructOperator<Complete<OperatorBase>, OperatorBase::dynamic>()(
2283 template <
class OperatorBase,
class T1,
class T2>
2285 return constructOperator<Complete<OperatorBase>, OperatorBase::dynamic>()(
2288 template <
class OperatorBase,
class T1,
class T2,
class T3>
2290 const T3 &x3)
const {
2291 return constructOperator<Complete<OperatorBase>, OperatorBase::dynamic>()(
2294 template <
class OperatorBase,
class T1,
class T2,
class T3,
class T4>
2296 const T4 &x4)
const {
2297 return constructOperator<Complete<OperatorBase>, OperatorBase::dynamic>()(
2301 static const int independent_variable =
true;
2302 template <
class Type>
2304 template <
class Type>
2306 const char *op_name();
2310 static const bool is_linear =
true;
2311 static const int dependent_variable =
true;
2312 static const bool have_eval =
true;
2313 template <
class Type>
2314 Type eval(Type x0) {
2317 template <
class Type>
2319 args.
dx(0) += args.
dy(0);
2321 const char *op_name();
2325 static const bool is_linear =
true;
2326 static const bool is_constant =
true;
2327 template <
class Type>
2330 template <
class Type>
2332 const char *op_name();
2337 static const bool is_linear =
true;
2339 template <
class Type>
2341 template <
class Type>
2343 const char *op_name();
2356 static const bool add_forward_replay_copy =
true;
2358 template <
class Type>
2360 for (Index i = 0; i < Base::noutput; i++) args.
y(i) = Type(0);
2362 template <
class Type>
2364 const char *op_name();
2368 void operator()(Replay *x, Index n);
2373 const char *op_name();
2381 NullOp2(Index ninput, Index noutput);
2382 const char *op_name();
2409 static const bool dynamic =
true;
2419 template <
class Type>
2421 TMBAD_ASSERT2(
false,
2422 "Reverse mode updates are forbidden until all references " 2427 const char *op_name();
2441 void set_fuse(
bool flag);
2447 template <
class OperatorBase>
2450 ans.index = this->values.size();
2452 this->values.push_back(result);
2455 add_to_opstack(pOp);
2457 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(values.size()));
2461 template <
class OperatorBase>
2464 ans.index = this->values.size();
2466 this->values.push_back(OperatorBase().eval(x.Value()));
2468 this->inputs.push_back(x.index);
2471 add_to_opstack(pOp);
2473 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(values.size()));
2474 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(inputs.size()));
2478 template <
class OperatorBase>
2481 ans.index = this->values.size();
2483 this->values.push_back(OperatorBase().eval(x.Value(), y.Value()));
2485 this->inputs.push_back(x.index);
2486 this->inputs.push_back(y.index);
2489 add_to_opstack(pOp);
2491 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(values.size()));
2492 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(inputs.size()));
2495 template <
class OperatorBase>
2498 IndexPair ptr((Index)inputs.size(), (Index)values.size());
2500 this->
template getOperator<OperatorBase>(lhs, rhs);
2503 inputs.push_back(lhs.index());
2504 inputs.push_back(rhs.index());
2505 if (more.size() > 0) inputs.push_back(more.index());
2507 values.resize(values.size() + n);
2512 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(values.size()));
2513 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(inputs.size()));
2517 template <
class OperatorBase>
2521 OperatorBase::dynamic,
2522 "Unlikely that you want to use this method for static operators?");
2524 OperatorBase::ninput == 0 || OperatorBase::implicit_dependencies,
2525 "Operators with pointer inputs should always implement " 2526 "'implicit_dependencies'");
2528 IndexPair ptr((Index)inputs.size(), (Index)values.size());
2531 TMBAD_ASSERT((Index)(lhs.size() > 0) + (Index)(rhs.size() > 0) ==
2533 if (lhs.size() > 0) inputs.push_back(lhs.index());
2534 if (rhs.size() > 0) inputs.push_back(rhs.index());
2536 values.resize(values.size() + n);
2541 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(values.size()));
2542 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(inputs.size()));
2547 template <
class OperatorBase>
2549 const std::vector<ad_plain> &x) {
2550 IndexPair ptr((Index)inputs.size(), (Index)values.size());
2554 for (
size_t i = 0; i < m; i++) inputs.push_back(x[i].index);
2556 values.resize(values.size() + n);
2561 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(values.size()));
2562 TMBAD_ASSERT(!TMBAD_INDEX_OVERFLOW(inputs.size()));
2563 std::vector<ad_plain> out(n);
2564 for (
size_t i = 0; i < n; i++) out[i].index = ans.index() + i;
2570 static const Index NA = (Index)-1;
2571 bool initialized()
const;
2572 bool on_some_tape()
const;
2574 void addToTape()
const;
2580 void override_by(
const ad_plain &x)
const;
2595 static const bool have_eval =
true;
2596 template <
class Type>
2597 Type eval(Type x0) {
2600 Replay eval(Replay x0);
2601 template <
class Type>
2603 args.
dx(0) += args.
dy(0);
2605 const char *op_name();
2614 ad_plain copy()
const;
2626 static const bool have_dependencies =
true;
2627 static const bool have_eval =
true;
2629 template <
class Type>
2630 Type eval(Type x0) {
2633 Replay eval(Replay x0);
2635 template <
class Type>
2643 void dependencies(
Args<> &args, Dependencies &dep)
const;
2644 const char *op_name();
2649 ad_plain copy0()
const;
2651 template <
bool left_var,
bool right_var>
2652 struct AddOp_ : BinaryOperator {
2653 static const bool is_linear =
true;
2654 static const bool have_eval =
true;
2655 template <
class Type>
2656 Type eval(Type x0, Type x1) {
2659 template <
class Type>
2661 if (left_var) args.
dx(0) += args.
dy(0);
2662 if (right_var) args.
dx(1) += args.
dy(0);
2664 const char *op_name() {
return "AddOp"; }
2666 if (other ==
get_glob()->getOperator<MulOp>()) {
2672 typedef AddOp_<true, true> AddOp;
2673 ad_plain operator+(
const ad_plain &other)
const;
2675 template <
bool left_var,
bool right_var>
2676 struct SubOp_ : BinaryOperator {
2677 static const bool is_linear =
true;
2678 static const bool have_eval =
true;
2679 template <
class Type>
2680 Type eval(Type x0, Type x1) {
2683 template <
class Type>
2685 if (left_var) args.
dx(0) += args.
dy(0);
2686 if (right_var) args.
dx(1) -= args.
dy(0);
2688 const char *op_name() {
return "SubOp"; }
2690 typedef SubOp_<true, true> SubOp;
2691 ad_plain operator-(
const ad_plain &other)
const;
2693 template <
bool left_var,
bool right_var>
2694 struct MulOp_ : BinaryOperator {
2695 static const bool have_eval =
true;
2696 static const bool is_linear = !left_var || !right_var;
2697 template <
class Type>
2698 Type eval(Type x0, Type x1) {
2701 template <
class Type>
2703 if (left_var) args.
dx(0) += args.
x(1) * args.
dy(0);
2704 if (right_var) args.
dx(1) += args.
x(0) * args.
dy(0);
2706 const char *op_name() {
return "MulOp"; }
2708 typedef MulOp_<true, true> MulOp;
2709 ad_plain
operator*(
const ad_plain &other)
const;
2710 ad_plain operator*(
const Scalar &other)
const;
2712 template <
bool left_var,
bool right_var>
2713 struct DivOp_ : BinaryOperator {
2714 static const bool have_eval =
true;
2715 template <
class Type>
2716 Type eval(Type x0, Type x1) {
2719 template <
class Type>
2721 Type tmp0 = args.
dy(0) / args.
x(1);
2722 if (left_var) args.
dx(0) += tmp0;
2723 if (right_var) args.
dx(1) -= args.
y(0) * tmp0;
2725 const char *op_name() {
return "DivOp"; }
2727 typedef DivOp_<true, true> DivOp;
2728 ad_plain operator/(
const ad_plain &other)
const;
2730 struct NegOp : UnaryOperator {
2731 static const bool is_linear =
true;
2732 static const bool have_eval =
true;
2733 template <
class Type>
2734 Type eval(Type x0) {
2737 template <
class Type>
2739 args.
dx(0) -= args.
dy(0);
2741 const char *op_name();
2743 ad_plain operator-()
const;
2745 ad_plain &operator+=(
const ad_plain &other);
2746 ad_plain &operator-=(
const ad_plain &other);
2747 ad_plain &operator*=(
const ad_plain &other);
2748 ad_plain &operator/=(
const ad_plain &other);
2754 Scalar Value()
const;
2755 Scalar Value(
global *glob)
const;
2772 void Independent(std::vector<ad_plain> &x);
2798 ad_segment(Replay *x,
size_t n,
bool zero_check =
false);
2799 bool identicalZero();
2800 bool all_on_active_tape(Replay *x,
size_t n);
2801 bool is_contiguous(Replay *x,
size_t n);
2802 bool all_zero(Replay *x,
size_t n);
2803 bool all_constant(Replay *x,
size_t n);
2804 size_t size()
const;
2805 size_t rows()
const;
2806 size_t cols()
const;
2808 ad_plain operator[](
size_t i)
const;
2809 ad_plain offset()
const;
2810 Index index()
const;
2838 TMBAD_UNION_OR_STRUCT {
2844 bool on_some_tape()
const;
2846 bool on_active_tape()
const;
2848 bool ontape()
const;
2852 bool constant()
const;
2853 Index index()
const;
2861 Scalar Value()
const;
2876 void addToTape()
const;
2880 void override_by(
const ad_plain &x)
const;
2882 bool in_context_stack(
global *glob)
const;
2890 bool identicalZero()
const;
2893 bool identicalOne()
const;
2897 bool bothConstant(
const ad_aug &other)
const;
2901 bool identical(
const ad_aug &other)
const;
2914 ad_aug operator-()
const;
2946 void Independent(std::vector<ad_aug> &x);
2949 template <
class S,
class T>
2950 std::ostream &operator<<(std::ostream &os, const std::pair<S, T> &x) {
2951 os <<
"(" << x.first <<
", " << x.second <<
")";
2955 std::ostream &operator<<(std::ostream &os,
const global::ad_plain &x);
2956 std::ostream &operator<<(std::ostream &os,
const global::ad_aug &x);
2971 bool operator==(
const T &other)
const {
2972 return this->Value() == other.Value();
2974 bool operator!=(
const T &other)
const {
2975 return this->Value() != other.Value();
2977 bool operator>=(
const T &other)
const {
2978 return this->Value() >= other.Value();
2980 bool operator<=(
const T &other)
const {
2981 return this->Value() <= other.Value();
2983 bool operator<(
const T &other)
const {
return this->Value() < other.Value(); }
2984 bool operator>(
const T &other)
const {
return this->Value() > other.Value(); }
2986 adaptive operator+(
const T &other)
const {
2989 adaptive operator-(
const T &other)
const {
2995 adaptive operator/(
const T &other)
const {
3002 typedef global::ad_plain ad_plain;
3004 typedef global::Replay Replay;
3018 struct ad_aug_index : ad_aug {
3019 ad_aug_index(
const Index &i);
3020 ad_aug_index(
const ad_aug &x);
3021 ad_aug_index(
const ad_plain &x);
3025 void Independent(std::vector<T> &x) {
3026 for (
size_t i = 0; i < x.size(); i++) x[i].Independent();
3029 void Dependent(std::vector<T> &x) {
3030 for (
size_t i = 0; i < x.size(); i++) x[i].Dependent();
3036 Scalar Value(Scalar x);
3048 for (
size_t i = 0; i < (size_t)x.size(); i++) {
3049 if (!x[i].on_some_tape()) {
3053 Index j = ad_plain(x[i]).index;
3055 if (j != j_previous + 1) {
3073 for (
size_t i = 0; i < (size_t)x.size(); i++) y[i] = x[i].copy();
3086 ad_aug operator+(
const double &x,
const ad_aug &y);
3087 ad_aug operator-(
const double &x,
const ad_aug &y);
3088 ad_aug
operator*(
const double &x,
const ad_aug &y);
3089 ad_aug operator/(
const double &x,
const ad_aug &y);
3091 bool operator<(
const double &x,
const ad_adapt &y);
3092 bool operator<=(
const double &x,
const ad_adapt &y);
3093 bool operator>(
const double &x,
const ad_adapt &y);
3094 bool operator>=(
const double &x,
const ad_adapt &y);
3095 bool operator==(
const double &x,
const ad_adapt &y);
3096 bool operator!=(
const double &x,
const ad_adapt &y);
3101 Writer floor(
const Writer &x);
3103 static const bool have_eval =
true;
3104 template <
class Type>
3108 template <
class Type>
3110 const char *op_name();
3112 ad_plain floor(
const ad_plain &x);
3113 ad_aug floor(
const ad_aug &x);
3114 Writer ceil(
const Writer &x);
3116 static const bool have_eval =
true;
3117 template <
class Type>
3121 template <
class Type>
3123 const char *op_name();
3125 ad_plain ceil(
const ad_plain &x);
3126 ad_aug ceil(
const ad_aug &x);
3127 Writer trunc(
const Writer &x);
3129 static const bool have_eval =
true;
3130 template <
class Type>
3134 template <
class Type>
3136 const char *op_name();
3138 ad_plain trunc(
const ad_plain &x);
3139 ad_aug trunc(
const ad_aug &x);
3140 Writer round(
const Writer &x);
3142 static const bool have_eval =
true;
3143 template <
class Type>
3147 template <
class Type>
3149 const char *op_name();
3151 ad_plain round(
const ad_plain &x);
3152 ad_aug round(
const ad_aug &x);
3154 double sign(
const double &x);
3155 Writer sign(
const Writer &x);
3157 static const bool have_eval =
true;
3158 template <
class Type>
3162 template <
class Type>
3164 const char *op_name();
3166 ad_plain sign(
const ad_plain &x);
3167 ad_aug sign(
const ad_aug &x);
3169 double ge0(
const double &x);
3170 double lt0(
const double &x);
3171 Writer ge0(
const Writer &x);
3173 static const bool have_eval =
true;
3174 template <
class Type>
3178 template <
class Type>
3180 const char *op_name();
3182 ad_plain ge0(
const ad_plain &x);
3183 ad_aug ge0(
const ad_aug &x);
3184 Writer lt0(
const Writer &x);
3186 static const bool have_eval =
true;
3187 template <
class Type>
3191 template <
class Type>
3193 const char *op_name();
3195 ad_plain lt0(
const ad_plain &x);
3196 ad_aug lt0(
const ad_aug &x);
3216 Writer fabs(
const Writer &x);
3218 static const bool have_eval =
true;
3219 template <
class Type>
3223 template <
class Type>
3225 args.
dx(0) += args.
dy(0) * sign(args.
x(0));
3228 const char *op_name();
3230 ad_plain fabs(
const ad_plain &x);
3231 ad_aug fabs(
const ad_aug &x);
3232 ad_adapt fabs(
const ad_adapt &x);
3233 Writer cos(
const Writer &x);
3234 ad_aug cos(
const ad_aug &x);
3235 Writer sin(
const Writer &x);
3237 static const bool have_eval =
true;
3238 template <
class Type>
3242 template <
class Type>
3244 args.
dx(0) += args.
dy(0) * cos(args.
x(0));
3247 const char *op_name();
3249 ad_plain sin(
const ad_plain &x);
3250 ad_aug sin(
const ad_aug &x);
3251 ad_adapt sin(
const ad_adapt &x);
3252 Writer cos(
const Writer &x);
3254 static const bool have_eval =
true;
3255 template <
class Type>
3259 template <
class Type>
3261 args.
dx(0) += args.
dy(0) * -sin(args.
x(0));
3264 const char *op_name();
3266 ad_plain cos(
const ad_plain &x);
3267 ad_aug cos(
const ad_aug &x);
3268 ad_adapt cos(
const ad_adapt &x);
3269 Writer exp(
const Writer &x);
3271 static const bool have_eval =
true;
3272 template <
class Type>
3276 template <
class Type>
3278 args.
dx(0) += args.
dy(0) * args.
y(0);
3281 const char *op_name();
3283 ad_plain exp(
const ad_plain &x);
3284 ad_aug exp(
const ad_aug &x);
3285 ad_adapt exp(
const ad_adapt &x);
3286 Writer log(
const Writer &x);
3288 static const bool have_eval =
true;
3289 template <
class Type>
3293 template <
class Type>
3295 args.
dx(0) += args.
dy(0) * Type(1.) / args.
x(0);
3298 const char *op_name();
3300 ad_plain log(
const ad_plain &x);
3301 ad_aug log(
const ad_aug &x);
3302 ad_adapt log(
const ad_adapt &x);
3303 Writer sqrt(
const Writer &x);
3305 static const bool have_eval =
true;
3306 template <
class Type>
3310 template <
class Type>
3312 args.
dx(0) += args.
dy(0) * Type(0.5) / args.
y(0);
3315 const char *op_name();
3317 ad_plain sqrt(
const ad_plain &x);
3318 ad_aug sqrt(
const ad_aug &x);
3319 ad_adapt sqrt(
const ad_adapt &x);
3320 Writer tan(
const Writer &x);
3322 static const bool have_eval =
true;
3323 template <
class Type>
3327 template <
class Type>
3329 args.
dx(0) += args.
dy(0) * Type(1.) / (cos(args.
x(0)) * cos(args.
x(0)));
3332 const char *op_name();
3334 ad_plain tan(
const ad_plain &x);
3335 ad_aug tan(
const ad_aug &x);
3336 ad_adapt tan(
const ad_adapt &x);
3337 Writer cosh(
const Writer &x);
3338 ad_aug cosh(
const ad_aug &x);
3339 Writer sinh(
const Writer &x);
3341 static const bool have_eval =
true;
3342 template <
class Type>
3346 template <
class Type>
3348 args.
dx(0) += args.
dy(0) * cosh(args.
x(0));
3351 const char *op_name();
3353 ad_plain sinh(
const ad_plain &x);
3354 ad_aug sinh(
const ad_aug &x);
3355 ad_adapt sinh(
const ad_adapt &x);
3356 Writer cosh(
const Writer &x);
3358 static const bool have_eval =
true;
3359 template <
class Type>
3363 template <
class Type>
3365 args.
dx(0) += args.
dy(0) * sinh(args.
x(0));
3368 const char *op_name();
3370 ad_plain cosh(
const ad_plain &x);
3371 ad_aug cosh(
const ad_aug &x);
3372 ad_adapt cosh(
const ad_adapt &x);
3373 Writer tanh(
const Writer &x);
3375 static const bool have_eval =
true;
3376 template <
class Type>
3380 template <
class Type>
3382 args.
dx(0) += args.
dy(0) * Type(1.) / (cosh(args.
x(0)) * cosh(args.
x(0)));
3385 const char *op_name();
3387 ad_plain tanh(
const ad_plain &x);
3388 ad_aug tanh(
const ad_aug &x);
3389 ad_adapt tanh(
const ad_adapt &x);
3390 Writer expm1(
const Writer &x);
3392 static const bool have_eval =
true;
3393 template <
class Type>
3397 template <
class Type>
3399 args.
dx(0) += args.
dy(0) * (args.
y(0) + Type(1.));  // d/dx expm1(x) = exp(x) = y + 1; parenthesize so the adjoint scales the whole factor
3402 const char *op_name();
3404 ad_plain expm1(
const ad_plain &x);
3405 ad_aug expm1(
const ad_aug &x);
3406 ad_adapt expm1(
const ad_adapt &x);
3407 Writer log1p(
const Writer &x);
3409 static const bool have_eval =
true;
3410 template <
class Type>
3414 template <
class Type>
3416 args.
dx(0) += args.
dy(0) * Type(1.) / (args.
x(0) + Type(1.));
3419 const char *op_name();
3421 ad_plain log1p(
const ad_plain &x);
3422 ad_aug log1p(
const ad_aug &x);
3423 ad_adapt log1p(
const ad_adapt &x);
3424 Writer asin(
const Writer &x);
3426 static const bool have_eval =
true;
3427 template <
class Type>
3431 template <
class Type>
3434 args.
dy(0) * Type(1.) / sqrt(Type(1.) - args.
x(0) * args.
x(0));
3437 const char *op_name();
3439 ad_plain asin(
const ad_plain &x);
3440 ad_aug asin(
const ad_aug &x);
3441 ad_adapt asin(
const ad_adapt &x);
3442 Writer acos(
const Writer &x);
3444 static const bool have_eval =
true;
3445 template <
class Type>
3449 template <
class Type>
3452 args.
dy(0) * Type(-1.) / sqrt(Type(1.) - args.
x(0) * args.
x(0));
3455 const char *op_name();
3457 ad_plain acos(
const ad_plain &x);
3458 ad_aug acos(
const ad_aug &x);
3459 ad_adapt acos(
const ad_adapt &x);
3460 Writer atan(
const Writer &x);
3462 static const bool have_eval =
true;
3463 template <
class Type>
3467 template <
class Type>
3469 args.
dx(0) += args.
dy(0) * Type(1.) / (Type(1.) + args.
x(0) * args.
x(0));
3472 const char *op_name();
3474 ad_plain atan(
const ad_plain &x);
3475 ad_aug atan(
const ad_aug &x);
3476 ad_adapt atan(
const ad_adapt &x);
3477 Writer asinh(
const Writer &x);
3479 static const bool have_eval =
true;
3480 template <
class Type>
3484 template <
class Type>
3487 args.
dy(0) * Type(1.) / sqrt(args.
x(0) * args.
x(0) + Type(1.));
3490 const char *op_name();
3492 ad_plain asinh(
const ad_plain &x);
3493 ad_aug asinh(
const ad_aug &x);
3494 ad_adapt asinh(
const ad_adapt &x);
3495 Writer acosh(
const Writer &x);
3497 static const bool have_eval =
true;
3498 template <
class Type>
3502 template <
class Type>
3505 args.
dy(0) * Type(1.) / sqrt(args.
x(0) * args.
x(0) - Type(1.));
3508 const char *op_name();
3510 ad_plain acosh(
const ad_plain &x);
3511 ad_aug acosh(
const ad_aug &x);
3512 ad_adapt acosh(
const ad_adapt &x);
3513 Writer atanh(
const Writer &x);
3515 static const bool have_eval =
true;
3516 template <
class Type>
3520 template <
class Type>
3522 args.
dx(0) += args.
dy(0) * Type(1.) / (Type(1) - args.
x(0) * args.
x(0));
3525 const char *op_name();
3527 ad_plain atanh(
const ad_plain &x);
3528 ad_aug atanh(
const ad_aug &x);
3529 ad_adapt atanh(
const ad_adapt &x);
3536 Writer pow(
const Writer &x1,
const Writer &x2);
3538 static const bool have_eval =
true;
3539 template <
class Type>
3540 Type eval(Type x1, Type x2) {
3543 template <
class Type>
3545 args.
dx(0) += args.
dy(0) * args.
x(1) * pow(args.
x(0), args.
x(1) - Type(1.));
3546 args.
dx(1) += args.
dy(0) * args.
y(0) * log(args.
x(0));
3548 const char *op_name();
3550 ad_plain pow(
const ad_plain &x1,
const ad_plain &x2);
3551 ad_aug pow(
const ad_aug &x1,
const ad_aug &x2);
3552 ad_adapt pow(
const ad_adapt &x1,
const ad_adapt &x2);
3554 Writer atan2(
const Writer &x1,
const Writer &x2);
3556 static const bool have_eval =
true;
3557 template <
class Type>
3558 Type eval(Type x1, Type x2) {
3559 return atan2(x1, x2);
3561 template <
class Type>
3563 args.
dx(0) += args.
dy(0) * args.
x(1) /
3564 (args.
x(0) * args.
x(0) + args.
x(1) * args.
x(1));
3565 args.
dx(1) += args.
dy(0) * -args.
x(0) /
3566 (args.
x(0) * args.
x(0) + args.
x(1) * args.
x(1));
3568 const char *op_name();
3570 ad_plain atan2(
const ad_plain &x1,
const ad_plain &x2);
3571 ad_aug atan2(
const ad_aug &x1,
const ad_aug &x2);
3572 ad_adapt atan2(
const ad_adapt &x1,
const ad_adapt &x2);
3574 Writer max(
const Writer &x1,
const Writer &x2);
3576 static const bool have_eval =
true;
3577 template <
class Type>
3578 Type eval(Type x1, Type x2) {
3581 template <
class Type>
3583 args.
dx(0) += args.
dy(0) * ge0(args.
x(0) - args.
x(1));
3584 args.
dx(1) += args.
dy(0) * lt0(args.
x(0) - args.
x(1));
3586 const char *op_name();
3588 ad_plain max(
const ad_plain &x1,
const ad_plain &x2);
3589 ad_aug max(
const ad_aug &x1,
const ad_aug &x2);
3590 ad_adapt max(
const ad_adapt &x1,
const ad_adapt &x2);
3593 Writer min(
const Writer &x1,
const Writer &x2);
3595 static const bool have_eval =
true;
3596 template <
class Type>
3597 Type eval(Type x1, Type x2) {
3600 template <
class Type>
3602 args.
dx(0) += args.
dy(0) * ge0(args.
x(1) - args.
x(0));
3603 args.
dx(1) += args.
dy(0) * lt0(args.
x(1) - args.
x(0));
3605 const char *op_name();
3607 ad_plain min(
const ad_plain &x1,
const ad_plain &x2);
3608 ad_aug min(
const ad_aug &x1,
const ad_aug &x2);
3609 ad_adapt min(
const ad_adapt &x1,
const ad_adapt &x2);
3610 Replay CondExpEq(
const Replay &x0,
const Replay &x1,
const Replay &x2,
3619 template <
class Type>
3621 TMBAD_ASSERT(
false);
3623 template <
class Type>
3625 TMBAD_ASSERT(
false);
3627 const char *op_name();
3629 Scalar CondExpEq(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3631 ad_plain CondExpEq(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3632 const ad_plain &x3);
3633 ad_aug CondExpEq(
const ad_aug &x0,
const ad_aug &x1,
const ad_aug &x2,
3635 Replay CondExpNe(
const Replay &x0,
const Replay &x1,
const Replay &x2,
3644 template <
class Type>
3646 TMBAD_ASSERT(
false);
3648 template <
class Type>
3650 TMBAD_ASSERT(
false);
3652 const char *op_name();
3654 Scalar CondExpNe(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3656 ad_plain CondExpNe(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3657 const ad_plain &x3);
3658 ad_aug CondExpNe(
const ad_aug &x0,
const ad_aug &x1,
const ad_aug &x2,
3660 Replay CondExpGt(
const Replay &x0,
const Replay &x1,
const Replay &x2,
3669 template <
class Type>
3671 TMBAD_ASSERT(
false);
3673 template <
class Type>
3675 TMBAD_ASSERT(
false);
3677 const char *op_name();
3679 Scalar CondExpGt(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3681 ad_plain CondExpGt(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3682 const ad_plain &x3);
3683 ad_aug CondExpGt(
const ad_aug &x0,
const ad_aug &x1,
const ad_aug &x2,
3685 Replay CondExpLt(
const Replay &x0,
const Replay &x1,
const Replay &x2,
3694 template <
class Type>
3696 TMBAD_ASSERT(
false);
3698 template <
class Type>
3700 TMBAD_ASSERT(
false);
3702 const char *op_name();
3704 Scalar CondExpLt(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3706 ad_plain CondExpLt(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3707 const ad_plain &x3);
3708 ad_aug CondExpLt(
const ad_aug &x0,
const ad_aug &x1,
const ad_aug &x2,
3710 Replay CondExpGe(
const Replay &x0,
const Replay &x1,
const Replay &x2,
3719 template <
class Type>
3721 TMBAD_ASSERT(
false);
3723 template <
class Type>
3725 TMBAD_ASSERT(
false);
3727 const char *op_name();
3729 Scalar CondExpGe(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3731 ad_plain CondExpGe(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3732 const ad_plain &x3);
3733 ad_aug CondExpGe(
const ad_aug &x0,
const ad_aug &x1,
const ad_aug &x2,
3735 Replay CondExpLe(
const Replay &x0,
const Replay &x1,
const Replay &x2,
3744 template <
class Type>
3746 TMBAD_ASSERT(
false);
3748 template <
class Type>
3750 TMBAD_ASSERT(
false);
3752 const char *op_name();
3754 Scalar CondExpLe(
const Scalar &x0,
const Scalar &x1,
const Scalar &x2,
3756 ad_plain CondExpLe(
const ad_plain &x0,
const ad_plain &x1,
const ad_plain &x2,
3757 const ad_plain &x3);
3758 ad_aug CondExpLe(
const ad_aug &x0,
const ad_aug &x1,
const ad_aug &x2,
3761 template <
class Info>
3765 InfoOp(Index n, Info info) : n(n), info(info) {}
3766 static const bool elimination_protected =
true;
3767 static const bool add_forward_replay_copy =
true;
3768 static const bool have_input_size_output_size =
true;
3769 template <
class Type>
3771 template <
class Type>
3773 Index input_size()
const {
return n; }
3774 Index output_size()
const {
return 0; }
3775 const char *op_name() {
return "InfoOp"; }
3777 Rcout << cfg.prefix << info << std::endl;
3779 void *operator_data() {
return &info; }
3781 template <
class Info>
3782 void addInfo(
const std::vector<ad_aug> &x,
const Info &info) {
3785 template <
class Info>
3786 void addInfo(
const std::vector<double> &x,
const Info &info) {}
3789 static const bool is_linear =
true;
3790 static const bool have_input_size_output_size =
true;
3791 static const bool add_forward_replay_copy =
true;
3793 Index input_size()
const;
3794 Index output_size()
const;
3796 template <
class Type>
3799 for (
size_t i = 0; i < n; i++) {
3800 args.
y(0) += args.
x(i);
3803 template <
class Type>
3805 for (
size_t i = 0; i < n; i++) {
3806 args.
dx(i) += args.
dy(0);
3809 const char *op_name();
3812 T
sum(
const std::vector<T> &x) {
3816 ad_plain logspace_sum(
const std::vector<ad_plain> &x);
3819 static const bool have_input_size_output_size =
true;
3820 Index input_size()
const;
3821 Index output_size()
const;
3822 LogSpaceSumOp(
size_t n);
3825 template <
class Type>
3827 for (
size_t i = 0; i < n; i++) {
3828 args.
dx(i) += exp(args.
x(i) - args.
y(0)) * args.
dy(0);
3831 const char *op_name();
3833 ad_plain logspace_sum(
const std::vector<ad_plain> &x);
3835 T logspace_sum(
const std::vector<T> &x_) {
3836 std::vector<ad_plain> x(x_.begin(), x_.end());
3837 return logspace_sum(x);
3840 ad_plain logspace_sum_stride(
const std::vector<ad_plain> &x,
3841 const std::vector<Index> &stride,
size_t n);
3843 std::vector<Index> stride;
3845 static const bool have_input_size_output_size =
true;
3847 Index number_of_terms()
const;
3848 template <
class Type>
3849 Type &entry(Type **px,
size_t i,
size_t j)
const {
3850 return px[j][0 + i * stride[j]];
3852 template <
class Type>
3853 Type rowsum(Type **px,
size_t i)
const {
3854 size_t m = stride.size();
3855 Type s = (Scalar)(0);
3856 for (
size_t j = 0; j < m; j++) {
3857 s += entry(px, i, j);
3861 Index input_size()
const;
3862 Index output_size()
const;
3863 LogSpaceSumStrideOp(std::vector<Index> stride,
size_t n);
3866 template <
class Type>
3868 size_t m = stride.size();
3869 std::vector<Type *> wrk1(m);
3870 std::vector<Type *> wrk2(m);
3871 Type **px = &(wrk1[0]);
3872 Type **pdx = &(wrk2[0]);
3873 for (
size_t i = 0; i < m; i++) {
3874 px[i] = args.
x_ptr(i);
3877 for (
size_t i = 0; i < n; i++) {
3878 Type s = rowsum(px, i);
3879 Type tmp = exp(s - args.
y(0)) * args.
dy(0);
3880 for (
size_t j = 0; j < m; j++) {
3881 entry(pdx, i, j) += tmp;
3889 void dependencies(
Args<> &args, Dependencies &dep)
const;
3891 static const bool have_dependencies =
true;
3893 static const bool implicit_dependencies =
true;
3895 static const bool allow_remap =
false;
3896 const char *op_name();
3901 ad_plain logspace_sum_stride(
const std::vector<ad_plain> &x,
3902 const std::vector<Index> &stride,
size_t n);
3904 T logspace_sum_stride(
const std::vector<T> &x_,
3905 const std::vector<Index> &stride,
size_t n) {
3906 std::vector<ad_plain> x(x_.begin(), x_.end());
3907 return logspace_sum_stride(x, stride, n);
3910 #endif // HAVE_GLOBAL_HPP Automatic differentiation library designed for TMB.
segment_ref< ReverseArgs, dx_write > dx_segment(Index from, Index size)
segment version
Add zero allocated workspace to the tape.
void reverse_decr(ReverseArgs< Writer > &args)
Source code writer.
virtual Index output_size()=0
Number of outputs from this OperatorPure.
Is this a linear operator?
bool in_use
Is this glob present in the context stack?
Does this operator require dynamic allocation?
Type y(Index j) const
j'th output variable of this operator
segment_ref< ReverseArgs, x_read > x_segment(Index from, Index size)
segment version
segment_ref< ForwardArgs, x_read > x_segment(Index from, Index size)
segment version
Index output_size()
Number of outputs from this OperatorPure.
void reverse(ReverseArgs< Scalar > &args)
Update input derivs of this OperatorPure.
IntRep code
Internal integer representation.
Operator with input/output dimension known at compile time.
ad_plain add_to_stack(const ad_plain &x, const ad_plain &y)
Add binary operator to the stack based on its two arguments
virtual void * operator_data()=0
Optional operator_data.
void * operator_data()
Return operator specific dynamic information (optional)
std::vector< Index > inv_seed
Optionally control seeding of InvOp in case strong_inv=true
operation_stack opstack
Operation stack.
void sort_inplace(std::vector< T > &x)
Utility: sort inplace.
void deallocate()
Deallocate this OperatorPure.
Replicate an operator and apply input compression.
Add default implementation of mandatory member: forward from optional member eval ...
const Index * inputs
Array for indirect access of operator inputs.
segment_ref< ForwardArgs, y_write > y_segment(Index from, Index size)
segment version
void dependencies_updating(Args<> &args, Dependencies &dep) const
Default implementation of OperatorPure::dependencies_updating()
IndexVector inputs
Pointers into global::values determining operator inputs.
global * get_glob()
Get pointer to current global AD context (or NULL if no context is active).
ad_plain add_to_stack(Scalar result=0)
Add nullary operator to the stack based on its result
void * incomplete()
Get pointer to operator before it was completed.
void reverse_loop_subgraph(ReverseArgs &args) const
Generic reverse sweep along global::subgraph_seq.
Access input/output values and derivatives during a reverse pass. Write access granted for the input derivative only.
F & apply(F &f) const
Apply a functor to each interval.
void increment(IndexPair &ptr)
Increment input/output pointers to prepare for the next OperatorPure in the stack.
std::vector< Scalar > values
Contiguous workspace for taped variables (same length as global::derivs)
void dependencies_updating(Args<> &args, Dependencies &dep)
Get the indices of variables updated by this operator.
Is this a constant operator?
void forward(ForwardArgs< Replay > &args)
Replay operation sequence.
Generate all mandatory members.
Contiguous set of variables on the current tape.
void dependencies(Args<> &args, Dependencies &dep) const
Calculate all inputs.
void reverse_loop(ReverseArgs &args, size_t begin=0) const
Generic reverse sweep.
Provide inplace read access to value or derivative arrays.
void forward_loop(ForwardArgs &args, size_t begin=0) const
Generic forward sweep.
Access input/output values during a forward pass. Write access granted for the output value only.
OperatorPure * compress()
Attempt to apply input compression to this Rep operator.
void reverse(ReverseArgs< Replay > &args)
Replay operation sequence.
void forceContiguous(V &x)
Make contiguous ad vector.
void forward(ForwardArgs< bool > &args)
Mark forward dependencies.
void * identifier()
Operator identifier.
if_else< test3, AddDependencies< Result2 >, Result2 >::type Result3
Add default implementation of mandatory member: dependencies
Type * x_ptr(Index j)
pointer version - use with caution.
void forward_loop(ForwardArgs &args, size_t begin, const NodeFilter &node_filter) const
Generic forward sweep.
Enable weak comparison operators of an ad type.
void forward_loop_subgraph(ForwardArgs &args) const
Generic forward sweep along global::subgraph_seq.
bool deterministic
Deterministic hash codes?
void forward_incr(ForwardArgs< Scalar > &args)
Fast equivalent of combined forward() and increment()
bool insert(T a, T b)
Insert new interval [a,b].
Operator that requires dynamic allocation. Compile time known input size.
void hash(hash_t &h, T x) const
Simple hash code of scalar.
vector< Type > operator*(matrix< Type > A, vector< Type > x)
void forward_incr(ForwardArgs< Writer > &args)
Source code writer.
Configuration of print method.
segment_ref< ReverseArgs, dy_read > dy_segment(Index from, Index size)
segment version
V getContiguous(const V &x)
Get contiguous (deep) copy of this vector.
std::vector< Index > inv2op
Used to lookup operator (node) of an independent variable.
The abstract operator for the operation stack global::opstack
Type min(const vector< Type > &x)
Type x(Index j) const
j'th input variable of this operator
void forward(ForwardArgs< Writer > &args)
Source code writer.
void * operator_data()
Optional operator_data.
if_else< test6, AddForwardIncrReverseDecr< Result5 >, Result5 >::type Result6
Add default implementation of mandatory members: forward_incr and reverse_decr from forward and reverse.
bool strong_inv
Use unique code for each independent variable? (see hash_sweep)
void reverse_decr(ReverseArgs< Replay > &args)
Replay operation sequence.
virtual void forward(ForwardArgs< Scalar > &args)=0
Update output values of this OperatorPure.
std::vector< Index > inv_index
Pointers into global::values determining independent variables.
Type & y(Index j)
j'th output variable of this operator
bool strong_output
Use unique hash code for each output of an operator?
Substitute for std::vector<bool> with all elements true.
if_else< test5, Result4, AddForwardMarkReverseMark< Result4 > >::type Result5
Add default implementation of mandatory members: forward_mark and reverse_mark
Type x(Index j) const
j'th input variable of this operator
op_info any
Bitwise max of operator flags in this stack.
void sort_unique_inplace(std::vector< T > &x)
Utility: sort unique inplace.
Construct ad_plain from index.
void reverse(ReverseArgs< Type > &args)
Reverse method applicable for Scalar and bool case.
virtual const char * op_name()
Name of this OperatorPure.
Protect this operator from elimination by the tape optimizer ?
Empty operator with inputs and outputs.
Copy value and set derivative to zero.
void decrement(IndexPair &ptr)
Decrement input/output pointers to prepare for the previous OperatorPure in the stack.
OperatorPure * self_fuse()
Lookup table for operator fusion. Merge this OperatorPure with an identical copy. If no match return ...
std::vector< size_t > match(const std::vector< T > &x, const std::vector< T > &y)
Match x vector in y vector.
Type * dx_ptr(Index j)
pointer version - use with caution.
const char * op_name()
Name of this OperatorPure.
Struct defining the main AD context.
Provide read/write access to an array segment.
Add default implementation of mandatory members: forward and reverse from forward_incr and reverse_decr.
void reverse_loop(ReverseArgs &args, size_t begin, const NodeFilter &node_filter) const
Generic reverse sweep.
ad_plain taped_value
If taped_value is initialized (see ad_plain::initialize) this is the value of ad_aug.
void print(print_config cfg)
Print this operator (optional)
This operator may update existing variables ?
op_flag
Enumeration of selected boolean flags in global::Operator
Add default implementation of mandatory members: forward_mark and reverse_mark
void push_back(OperatorPure *x)
Add new operator to this stack and update bitwise operator information.
void reverse(ReverseArgs< bool > &args)
Mark reverse dependencies.
virtual void deallocate()=0
Deallocate this OperatorPure.
segment_ref< ReverseArgs, y_read > y_segment(Index from, Index size)
segment version
std::vector< Index > dep2op
Used to lookup operator (node) of a dependent variable.
op_info info()
Get operator info.
Type dy(Index j) const
Partial derivative of end result wrt. j'th output variable of this operator.
Utility for member completion.
Empty operator without inputs or outputs.
int IntRep
Type used for internal integer representation.
void reverse_decr(ReverseArgs< bool > &args)
Fast equivalent of combined decrement() and reverse()
virtual Index input_size()=0
Number of inputs to this OperatorPure.
Is output of this operator an independent variable ?
Operator auto-completion.
Reference a variable on another tape.
Add default implementation of mandatory members: increment and decrement
Operator that requires dynamic allocation. Compile time known input/output size.
OperatorPure * other_fuse(OperatorPure *other)
Lookup table for operator fusion. Merge this OperatorPure with another operator. If no match return N...
Index output(Index j) const
Get variable index of j'th output of current operator.
OperatorPure * copy()
Return a copy of this OperatorPure.
Is this operator a 'smart pointer' (with reference counting) ?
bool reduce
Reduce returned hash values to one per dependent variable?
OperatorPure * other_fuse(OperatorPure *self, OperatorPure *other)
How to fuse this operator (self) with another (other)
bool strong_const
Include numerical value as part of hash code for constants? (see hash_sweep)
if_else< test4, Result3, AddIncrementDecrement< Result3 > >::type Result4
Add default implementation of mandatory members: increment and decrement
Type sum(Vector< Type > x)
std::vector< bool > mark
Private workspace used by graph::search. Must either be empty or filled with false when not in use...
std::vector< Scalar > derivs
Contiguous workspace for derivatives (same length as global::values)
global * parent_glob
Previous ad context to be restored when this context ends.
void clear_array_subgraph(Vector &array, typename Vector::value_type value=typename Vector::value_type(0)) const
Generic clear array along global::subgraph.
Configuration of hash_sweep.
void forward_incr(ForwardArgs< bool > &args)
Fast equivalent of combined forward() and increment()
if_else< test7, AddForwardReverse< Result6 >, Result6 >::type Result7
Add default implementation of mandatory members: forward and reverse from forward_incr and reverse_decr.
double value(T x)
Namespace with utility functions for adaptive numerical integration.
void print(print_config cfg)
Optional print method.
Index input_size()
Number of inputs to this OperatorPure.
bool isContiguous(V &x)
Is this ad vector available as a contiguous block on the tape?
IndexPair ptr
Input/output pointers.
void reverse(ReverseArgs< Type > &args)
Reverse mode updates are forbidden until all references are resolved.
Type & dx(Index j)
Partial derivative of end result wrt. j'th input variable of this operator.
Union of closed intervals.
Operator graph in compressed row storage.
Is output of this operator a dependent variable ?
void reverse(ReverseArgs< Type > &args)
Derivatives in the dense case are zero.
void forward(ForwardArgs< Scalar > &args)
Update output values of this OperatorPure.
Index input(Index j) const
Get variable index of j'th input to current operator.
void forward_incr_mark_dense(ForwardArgs< bool > &args)
Conditionally mark all outputs.
void forward_incr(ForwardArgs< Replay > &args)
Replay operation sequence.
Bitwise collection of selected operator flags.
ArrayAccess
Define segment_ref array to access inside ForwardArgs or ReverseArgs
if_else< test1, OperatorBase, AddForwardFromEval< OperatorBase, OperatorBase::ninput > >::type Result1
Add default implementation of mandatory member: forward from optional member eval ...
void reverse_decr(ReverseArgs< Scalar > &args)
Fast equivalent of combined decrement() and reverse()
Add default implementation of mandatory members: forward_incr and reverse_decr from forward and reverse.
void reverse(ReverseArgs< Writer > &args)
Source code writer.
std::vector< ad_plain > add_to_stack(OperatorPure *pOp, const std::vector< ad_plain > &x)
Add vector operator to the stack based on its vector argument
Is it safe to remap the inputs of this operator?
void forward(ForwardArgs< Type > &args)
Forward method applicable for Scalar and bool case.
std::vector< Index > dep_index
Pointers into global::values determining dependent variables.
ad_plain add_to_stack(const ad_plain &x)
Add unary operator to the stack based on its argument
if_else< test2, Result1, AddInputSizeOutputSize< Result1 > >::type Result2
Add default implementation of mandatory members: input_size and output_size
Add default implementation of mandatory member: dependencies
void dependencies(Args<> &args, Dependencies &dep)
Get the indices of variables required by this operator.
std::vector< ad_plain > operator()(const std::vector< ad_plain > &x)
Move a stack allocated instance to the heap and let the operation_stack manage the memory...
Type max(const vector< Type > &x)