1 #ifndef HAVE_AD_BLAS_HPP 2 #define HAVE_AD_BLAS_HPP 22 template <
class Matrix>
25 Index j_previous = -1;
26 for (
size_t i = 0; i < (size_t)x.size(); i++) {
27 if (!x(i).on_some_tape()) {
31 Index j = ad_plain(x(i)).index;
33 if (j != j_previous + 1) {
45 for (
size_t i = 0; i < (size_t)x.size(); i++) {
46 ad_plain xi_cpy = x(i).copy();
48 x(i).override_by(xi_cpy);
49 if (i == 0) ans = xi_cpy;
// Dynamic-size (rows and columns chosen at runtime) matrix aliases used by the
// matmul overloads below: dmatrix holds plain double values, vmatrix holds
// global::Replay values (the overloads taking vmatrix arguments are the
// AD-variable versions).
// NOTE(review): Matrix/Dynamic presumably come from Eigen via a
// using-declaration outside this view -- confirm.
57 typedef Matrix<double, Dynamic, Dynamic> dmatrix;
58 typedef Matrix<global::Replay, Dynamic, Dynamic> vmatrix;
60 template <
class Target>
62 TMBAD_ASSERT((
size_t)y.size() == (size_t)x.size());
63 for (
size_t i = 0; i < (size_t)y.size(); i++) {
68 template <
bool XT,
bool YT,
bool ZT,
bool UP>
70 template <
bool XT,
bool YT,
bool ZT,
bool UP>
71 void matmul(
const vmatrix &x,
const vmatrix &y, Map<vmatrix> z) {
85 vmatrix
matmul(
const vmatrix &x,
const vmatrix &y);
88 dmatrix
matmul(
const dmatrix &x,
const dmatrix &y);
91 template <
bool XT,
bool YT,
bool ZT,
bool UP>
92 void matmul(Map<const dmatrix> x, Map<const dmatrix> y, Map<dmatrix> z) {
94 if (XT && YT && ZT) z.transpose() = x.transpose() * y.transpose();
95 if (!XT && YT && ZT) z.transpose() = x * y.transpose();
96 if (XT && !YT && ZT) z.transpose() = x.transpose() * y;
97 if (XT && YT && !ZT) z = x.transpose() * y.transpose();
98 if (!XT && !YT && ZT) z.transpose() = x * y;
99 if (XT && !YT && !ZT) z = x.transpose() * y;
100 if (!XT && YT && !ZT) z = x * y.transpose();
101 if (!XT && !YT && !ZT) z = x * y;
104 if (XT && YT && ZT) z.transpose() += x.transpose() * y.transpose();
105 if (!XT && YT && ZT) z.transpose() += x * y.transpose();
106 if (XT && !YT && ZT) z.transpose() += x.transpose() * y;
107 if (XT && YT && !ZT) z += x.transpose() * y.transpose();
108 if (!XT && !YT && ZT) z.transpose() += x * y;
109 if (XT && !YT && !ZT) z += x.transpose() * y;
110 if (!XT && YT && !ZT) z += x * y.transpose();
111 if (!XT && !YT && !ZT) z += x * y;
115 template <
bool XT,
bool YT,
bool ZT,
bool UP>
117 static const bool dynamic =
true;
118 static const int max_fuse_depth = 0;
120 static const int ninput = 2 + UP;
122 set_dim(X.rows(), X.cols(), Y.rows(), Y.cols());
124 MatMul(
int n1,
int n2,
int n3) : n1(n1), n2(n2), n3(n3) {}
125 Index input_size()
const {
return 2 + UP; }
126 Index output_size()
const {
128 int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
129 get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
130 return Zrows * Zcols;
132 static const bool have_input_size_output_size =
true;
133 void set_dim(
int Xrows,
int Xcols,
int Yrows,
int Ycols) {
136 n3 = (YT ? Yrows : Ycols);
138 void get_dim(
int &Xrows,
int &Xcols,
int &Yrows,
int &Ycols,
int &Zrows,
143 int Xop_rows = Xrows, Xop_cols = Xcols;
144 if (XT) std::swap(Xop_rows, Xop_cols);
146 int Yop_rows = Xop_cols, Yop_cols = n3;
150 if (YT) std::swap(Yrows, Ycols);
152 int Zop_rows = Xop_rows, Zop_cols = Yop_cols;
156 if (ZT) std::swap(Zrows, Zcols);
158 template <
class Type>
160 int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
161 get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
162 typedef Map<Matrix<Type, Dynamic, Dynamic> > MapMatrix;
163 typedef Map<const Matrix<Type, Dynamic, Dynamic> > ConstMapMatrix;
164 Type *zp = (UP ? args.
x_ptr(2) : args.
y_ptr(0));
165 ConstMapMatrix X(args.
x_ptr(0), Xrows, Xcols);
166 ConstMapMatrix Y(args.
x_ptr(1), Yrows, Ycols);
167 MapMatrix Z(zp, Zrows, Zcols);
168 matmul<XT, YT, ZT, UP>(X, Y, Z);
170 template <
class Type>
172 int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
173 get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
174 typedef Map<Matrix<Type, Dynamic, Dynamic> > MapMatrix;
175 typedef Map<const Matrix<Type, Dynamic, Dynamic> > ConstMapMatrix;
177 ConstMapMatrix X(args.
x_ptr(0), Xrows, Xcols);
178 ConstMapMatrix Y(args.
x_ptr(1), Yrows, Ycols);
179 ConstMapMatrix W(dzp, Zrows, Zcols);
180 MapMatrix DX(args.
dx_ptr(0), Xrows, Xcols);
181 MapMatrix DY(args.
dx_ptr(1), Yrows, Ycols);
183 matmul<ZT, !YT, XT, true>(W, Y, DX);
184 matmul<!XT, ZT, YT, true>(X, W, DY);
187 void dependencies(
Args<> &args, Dependencies &dep)
const {
188 int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
189 get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
190 dep.add_segment(args.
input(0), Xrows * Xcols);
191 dep.add_segment(args.
input(1), Yrows * Ycols);
194 void dependencies_updating(
Args<> &args, Dependencies &dep)
const {
195 int Xrows, Xcols, Yrows, Ycols, Zrows, Zcols;
196 get_dim(Xrows, Xcols, Yrows, Ycols, Zrows, Zcols);
198 dep.add_segment(args.
input(2), Zrows * Zcols);
201 static const bool have_dependencies =
true;
203 static const bool implicit_dependencies =
true;
205 static const bool allow_remap =
false;
207 static const bool updating =
true;
211 const char *op_name() {
return "MatMul"; }
215 #endif // HAVE_AD_BLAS_HPP Automatic differentiation library designed for TMB.
Operator with input/output dimension known at compile time.
global * get_glob()
Get a pointer to the current global AD context (or NULL if no context is active).
ad_plain add_to_stack(Scalar result=0)
Add a nullary operator to the stack based on its result.
Access input/output values and derivatives during a reverse pass. Write access is granted for the input derivatives only.
Contiguous set of variables on the current tape.
Type * y_ptr(Index j)
pointer version - use with caution.
Access input/output values during a forward pass. Write access granted for the output value only...
Type * x_ptr(Index j)
pointer version - use with caution.
Type * x_ptr(Index j)
pointer version - use with caution.
Type * dx_ptr(Index j)
pointer version - use with caution.
matrix< Type > matmul(matrix< Type > x, matrix< Type > y)
Matrix multiply.
global::ad_segment contiguousBlock(const Matrix &x)
Request a contiguous block on the tape.
Index input(Index j) const
Get the variable index of the j-th input to the current operator.
Type * dy_ptr(Index j)
pointer version - use with caution.