TMB Documentation  v1.9.11
ad_blas.hpp
1 #ifndef HAVE_AD_BLAS_HPP
2 #define HAVE_AD_BLAS_HPP
3 // Autogenerated - do not edit by hand !
4 #include <Eigen/Dense>
5 #include "global.hpp"
6 
7 namespace TMBad {
8 
// contiguousBlock: request a contiguous block on the tape holding the
// elements of the matrix `x`, returned as a global::ad_segment with the
// same rows/cols as `x`.
// The declaration line (listing line 23) is not shown in this extract; the
// reference list at the end of this page gives it as
// `global::ad_segment contiguousBlock(const Matrix &x)`.
22 template <class Matrix>
    // Pass 1: check whether every element is already on a tape and whether
    // the elements occupy consecutive tape indices (j == j_previous + 1).
24  bool yes = true;
25  Index j_previous = -1;
26  for (size_t i = 0; i < (size_t)x.size(); i++) {
27  if (!x(i).on_some_tape()) {
28  yes = false;
29  break;
30  }
31  Index j = ad_plain(x(i)).index;
32  if (i > 0) {
33  if (j != j_previous + 1) {
34  yes = false;
35  break;
36  }
37  }
38  j_previous = j;
39  }
    // Already contiguous: wrap the existing variables, starting at x(0).
    // NOTE(review): for an empty matrix `yes` is still true here, so x(0) is
    // evaluated on a zero-sized matrix - confirm callers never pass one.
40  if (yes) {
41  return global::ad_segment(ad_plain(x(0)), x.rows(), x.cols());
42  }
43 
    // Pass 2: not contiguous. Copy the elements one after another and let each
    // x(i) be overridden by its copy; the segment starts at the first copy.
    // Presumably successive copy() calls land on adjacent tape slots, which is
    // what makes the result contiguous - TODO confirm against copy().
44  ad_plain ans;
45  for (size_t i = 0; i < (size_t)x.size(); i++) {
46  ad_plain xi_cpy = x(i).copy();
47 
48  x(i).override_by(xi_cpy);
49  if (i == 0) ans = xi_cpy;
50  }
51  return global::ad_segment(ans, x.rows(), x.cols());
52 }
53 
// Eigen names used unqualified in the rest of this header.
54 using Eigen::Dynamic;
55 using Eigen::Map;
56 using Eigen::Matrix;
// dmatrix: dynamically sized matrix of plain double values.
57 typedef Matrix<double, Dynamic, Dynamic> dmatrix;
// vmatrix: dynamically sized matrix whose scalar type is global::Replay.
58 typedef Matrix<global::Replay, Dynamic, Dynamic> vmatrix;
59 
// fill: copy a tape segment into a target, element by element.
//   y  destination; must support y.size() and assignment through y(i).
//   x  source segment, read through x[i].
// The two sizes must agree; this is checked with TMBAD_ASSERT.
60 template <class Target>
61 void fill(Target &y, const global::ad_segment x) {
62  TMBAD_ASSERT((size_t)y.size() == (size_t)x.size());
63  for (size_t i = 0; i < (size_t)y.size(); i++) {
64  y(i) = x[i];
65  }
66 }
67 
// Forward declaration of the tape operator; the definition follows further
// down in this header.
68 template <bool XT, bool YT, bool ZT, bool UP>
69 struct MatMul;
// matmul (tape-recording overload): push a MatMul<XT, YT, ZT, UP> operator
// onto the current global tape (get_glob()).
//   UP == false: the operator is added with two inputs and its output segment
//                is copied into `z` via fill().
//   UP == true : the operator is added with three inputs and no result is
//                copied back here.
// The definitions of `xc`, `yc` and `zc` (listing lines 72, 73 and 79) are not
// shown in this extract; presumably they are the contiguous tape blocks of
// x, y and z - TODO confirm against the full source.
70 template <bool XT, bool YT, bool ZT, bool UP>
71 void matmul(const vmatrix &x, const vmatrix &y, Map<vmatrix> z) {
74  if (!UP) {
75  global::ad_segment out =
76  get_glob()->add_to_stack<MatMul<XT, YT, ZT, UP> >(xc, yc);
77  fill(z, out);
78  } else {
80  get_glob()->add_to_stack<MatMul<XT, YT, ZT, UP> >(xc, yc, zc);
81  }
82 }
83 
// Declaration only (no body in this header listing): matmul taking and
// returning vmatrix; presumably the plain product of x and y - TODO confirm.
85 vmatrix matmul(const vmatrix &x, const vmatrix &y);
86 
// Declaration only (no body in this header listing): the same interface for
// plain double matrices.
88 dmatrix matmul(const dmatrix &x, const dmatrix &y);
89 
// matmul (numeric kernel on mapped double matrices).
// Template flags:
//   XT, YT  use x / y transposed in the product
//   ZT      write to z.transpose() instead of z
//   UP      false: assign the product (=); true: add it to z (+=)
// Exactly one of the eight flag combinations matches in each group below, so
// one Eigen expression is evaluated per call.
91 template <bool XT, bool YT, bool ZT, bool UP>
92 void matmul(Map<const dmatrix> x, Map<const dmatrix> y, Map<dmatrix> z) {
    // Overwrite mode: z (or z^T) = op(x) * op(y).
93  if (!UP) {
94  if (XT && YT && ZT) z.transpose() = x.transpose() * y.transpose();
95  if (!XT && YT && ZT) z.transpose() = x * y.transpose();
96  if (XT && !YT && ZT) z.transpose() = x.transpose() * y;
97  if (XT && YT && !ZT) z = x.transpose() * y.transpose();
98  if (!XT && !YT && ZT) z.transpose() = x * y;
99  if (XT && !YT && !ZT) z = x.transpose() * y;
100  if (!XT && YT && !ZT) z = x * y.transpose();
101  if (!XT && !YT && !ZT) z = x * y;
102  }
    // Update mode: z (or z^T) += op(x) * op(y).
103  if (UP) {
104  if (XT && YT && ZT) z.transpose() += x.transpose() * y.transpose();
105  if (!XT && YT && ZT) z.transpose() += x * y.transpose();
106  if (XT && !YT && ZT) z.transpose() += x.transpose() * y;
107  if (XT && YT && !ZT) z += x.transpose() * y.transpose();
108  if (!XT && !YT && ZT) z.transpose() += x * y;
109  if (XT && !YT && !ZT) z += x.transpose() * y;
110  if (!XT && YT && !ZT) z += x * y.transpose();
111  if (!XT && !YT && !ZT) z += x * y;
112  }
113 }
114 
// MatMul: tape operator for a matrix product with optional transposes.
// Template flags (same meaning as in the numeric matmul kernel above):
//   XT, YT  operands used transposed;  ZT  result stored transposed;
//   UP      "updating" variant: the result matrix is a third input that is
//           added to, and the operator reports zero outputs.
// Only three integers are stored (n1, n2, n3); all six matrix dimensions are
// rebuilt from them in get_dim().
115 template <bool XT, bool YT, bool ZT, bool UP>
116 struct MatMul : global::Operator<2 + UP, -1> {
    // Operator trait flags; their exact meaning is defined by the TMBad
    // framework outside this listing.
117  static const bool dynamic = true;
118  static const int max_fuse_depth = 0;
    // n1 = rows of X, n2 = cols of X, n3 = cols of op(Y) (see set_dim).
119  int n1, n2, n3;
120  static const int ninput = 2 + UP;
    // Body of a constructor whose header (listing line 121) is not shown in
    // this extract: it records the dimensions of its arguments X and Y.
122  set_dim(X.rows(), X.cols(), Y.rows(), Y.cols());
123  }
    // Construct directly from the three stored dimensions.
124  MatMul(int n1, int n2, int n3) : n1(n1), n2(n2), n3(n3) {}
    // Two inputs (X, Y), plus Z when updating.
125  Index input_size() const { return 2 + UP; }
    // Number of output values: none when updating, otherwise Zrows * Zcols.
126  Index output_size() const {
127  if (UP) return 0;
128  int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
129  get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
130  return Zrows * Zcols;
131  }
132  static const bool have_input_size_output_size = true;
    // Store the dimensions: X as given; for Y only the column count of op(Y),
    // which is Yrows when Y is used transposed and Ycols otherwise.
133  void set_dim(int Xrows, int Xcols, int Yrows, int Ycols) {
134  n1 = Xrows;
135  n2 = Xcols;
136  n3 = (YT ? Yrows : Ycols);
137  }
    // Rebuild the stored shapes of X, Y and Z (all output parameters) from
    // n1, n2, n3 and the transpose flags.
138  void get_dim(int &Xrows, int &Xcols, int &Yrows, int &Ycols, int &Zrows,
139  int &Zcols) const {
140  Xrows = n1;
141  Xcols = n2;
142 
    // Shape of op(X): X, or X transposed when XT.
143  int Xop_rows = Xrows, Xop_cols = Xcols;
144  if (XT) std::swap(Xop_rows, Xop_cols);
145 
    // op(Y) must conform with op(X): its row count equals op(X)'s columns.
146  int Yop_rows = Xop_cols, Yop_cols = n3;
147 
    // Undo the transpose to get Y's stored shape.
148  Yrows = Yop_rows;
149  Ycols = Yop_cols;
150  if (YT) std::swap(Yrows, Ycols);
151 
    // The product is op(X).rows x op(Y).cols; Z is stored transposed when ZT.
152  int Zop_rows = Xop_rows, Zop_cols = Yop_cols;
153 
154  Zrows = Zop_rows;
155  Zcols = Zop_cols;
156  if (ZT) std::swap(Zrows, Zcols);
157  }
    // Forward pass: map the raw value pointers as matrices and run the matmul
    // kernel. When updating, Z is the third input; otherwise it is the
    // operator's output.
158  template <class Type>
159  void forward(ForwardArgs<Type> &args) {
160  int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
161  get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
162  typedef Map<Matrix<Type, Dynamic, Dynamic> > MapMatrix;
163  typedef Map<const Matrix<Type, Dynamic, Dynamic> > ConstMapMatrix;
164  Type *zp = (UP ? args.x_ptr(2) : args.y_ptr(0));
165  ConstMapMatrix X(args.x_ptr(0), Xrows, Xcols);
166  ConstMapMatrix Y(args.x_ptr(1), Yrows, Ycols);
167  MapMatrix Z(zp, Zrows, Zcols);
168  matmul<XT, YT, ZT, UP>(X, Y, Z);
169  }
    // Reverse pass: W is the derivative w.r.t. Z (read from the third input's
    // derivative when updating, otherwise from the output's derivative). Both
    // calls use the updating kernel, so they add into DX and DY. With no
    // transpose flags set they reduce to DX += W * Y^T and DY += X^T * W.
170  template <class Type>
171  void reverse(ReverseArgs<Type> &args) {
172  int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
173  get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
174  typedef Map<Matrix<Type, Dynamic, Dynamic> > MapMatrix;
175  typedef Map<const Matrix<Type, Dynamic, Dynamic> > ConstMapMatrix;
176  Type *dzp = (UP ? args.dx_ptr(2) : args.dy_ptr(0));
177  ConstMapMatrix X(args.x_ptr(0), Xrows, Xcols);
178  ConstMapMatrix Y(args.x_ptr(1), Yrows, Ycols);
179  ConstMapMatrix W(dzp, Zrows, Zcols);
180  MapMatrix DX(args.dx_ptr(0), Xrows, Xcols);
181  MapMatrix DY(args.dx_ptr(1), Yrows, Ycols);
182 
183  matmul<ZT, !YT, XT, true>(W, Y, DX);
184  matmul<!XT, ZT, YT, true>(X, W, DY);
185  }
186 
    // Report the X and Y input segments as dependencies.
187  void dependencies(Args<> &args, Dependencies &dep) const {
188  int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
189  get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
190  dep.add_segment(args.input(0), Xrows * Xcols);
191  dep.add_segment(args.input(1), Yrows * Ycols);
192  }
193 
    // Report the Z segment (third input) as updated; only in the UP variant.
194  void dependencies_updating(Args<> &args, Dependencies &dep) const {
195  int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
196  get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
197  if (UP) {
198  dep.add_segment(args.input(2), Zrows * Zcols);
199  }
200  }
    // Further operator trait flags (listing lines 202, 204 and 206 are not
    // shown in this extract).
201  static const bool have_dependencies = true;
203  static const bool implicit_dependencies = true;
205  static const bool allow_remap = false;
207  static const bool updating = true;
208 
    // The Writer overloads are not implemented; reaching them trips
    // TMBAD_ASSERT(false).
209  void forward(ForwardArgs<Writer> &args) { TMBAD_ASSERT(false); }
210  void reverse(ReverseArgs<Writer> &args) { TMBAD_ASSERT(false); }
    // Operator name string.
211  const char *op_name() { return "MatMul"; }
212 };
213 
214 } // namespace TMBad
215 #endif // HAVE_AD_BLAS_HPP
Automatic differentiation library designed for TMB.
Definition: TMB.hpp:157
Operator with input/output dimension known at compile time.
Definition: global.hpp:1491
global * get_glob()
Get pointer to current global AD context (or NULL if no context is active).
Definition: TMBad.cpp:690
ad_plain add_to_stack(Scalar result=0)
Add nullary operator to the stack based on its result
Definition: global.hpp:2448
Access input/output values and derivatives during a reverse pass. Write access granted for the input ...
Definition: global.hpp:311
Contiguous set of variables on the current tape.
Definition: global.hpp:2780
Type * y_ptr(Index j)
pointer version - use with caution.
Definition: global.hpp:291
Access input/output values during a forward pass. Write access granted for the output value only...
Definition: global.hpp:279
Type * x_ptr(Index j)
pointer version - use with caution.
Definition: global.hpp:328
Type * x_ptr(Index j)
pointer version - use with caution.
Definition: global.hpp:289
Type * dx_ptr(Index j)
pointer version - use with caution.
Definition: global.hpp:332
matrix< Type > matmul(matrix< Type > x, matrix< Type > y)
Matrix multiply.
global::ad_segment contiguousBlock(const Matrix &x)
Request a contiguous block on the tape.
Definition: ad_blas.hpp:23
Index input(Index j) const
Get variable index of j'th input to current operator.
Definition: global.hpp:265
Type * dy_ptr(Index j)
pointer version - use with caution.
Definition: global.hpp:334
License: GPL v2