trace_inv_quad_form_ldlt.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
3 
5 #include <stan/math/rev/core.hpp>
8 #include <boost/utility/enable_if.hpp>
11 
12 namespace stan {
13 namespace math {
14 
15 namespace {
16 template <typename T2, int R2, int C2, typename T3, int R3, int C3>
17 class trace_inv_quad_form_ldlt_impl : public chainable_alloc {
18  protected:
19  inline void initializeB(const Eigen::Matrix<var, R3, C3> &B, bool haveD) {
20  Eigen::Matrix<double, R3, C3> Bd(B.rows(), B.cols());
21  variB_.resize(B.rows(), B.cols());
22  for (int j = 0; j < B.cols(); j++) {
23  for (int i = 0; i < B.rows(); i++) {
24  variB_(i, j) = B(i, j).vi_;
25  Bd(i, j) = B(i, j).val();
26  }
27  }
28  AinvB_ = ldlt_.solve(Bd);
29  if (haveD)
30  C_.noalias() = Bd.transpose() * AinvB_;
31  else
32  value_ = (Bd.transpose() * AinvB_).trace();
33  }
34  inline void initializeB(const Eigen::Matrix<double, R3, C3> &B, bool haveD) {
35  AinvB_ = ldlt_.solve(B);
36  if (haveD)
37  C_.noalias() = B.transpose() * AinvB_;
38  else
39  value_ = (B.transpose() * AinvB_).trace();
40  }
41 
42  template <int R1, int C1>
43  inline void initializeD(const Eigen::Matrix<var, R1, C1> &D) {
44  D_.resize(D.rows(), D.cols());
45  variD_.resize(D.rows(), D.cols());
46  for (int j = 0; j < D.cols(); j++) {
47  for (int i = 0; i < D.rows(); i++) {
48  variD_(i, j) = D(i, j).vi_;
49  D_(i, j) = D(i, j).val();
50  }
51  }
52  }
53  template <int R1, int C1>
54  inline void initializeD(const Eigen::Matrix<double, R1, C1> &D) {
55  D_ = D;
56  }
57 
58  public:
59  template <typename T1, int R1, int C1>
60  trace_inv_quad_form_ldlt_impl(const Eigen::Matrix<T1, R1, C1> &D,
61  const LDLT_factor<T2, R2, C2> &A,
62  const Eigen::Matrix<T3, R3, C3> &B)
63  : Dtype_(stan::is_var<T1>::value), ldlt_(A) {
64  initializeB(B, true);
65  initializeD(D);
66 
67  value_ = (D_ * C_).trace();
68  }
69 
70  trace_inv_quad_form_ldlt_impl(const LDLT_factor<T2, R2, C2> &A,
71  const Eigen::Matrix<T3, R3, C3> &B)
72  : Dtype_(2), ldlt_(A) {
73  initializeB(B, false);
74  }
75 
76  const int Dtype_; // 0 = double, 1 = var, 2 = missing
77  LDLT_factor<T2, R2, C2> ldlt_;
78  Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> D_;
79  Eigen::Matrix<vari *, Eigen::Dynamic, Eigen::Dynamic> variD_;
80  Eigen::Matrix<vari *, R3, C3> variB_;
81  Eigen::Matrix<double, R3, C3> AinvB_;
82  Eigen::Matrix<double, C3, C3> C_;
83  double value_;
84 };
85 
86 template <typename T2, int R2, int C2, typename T3, int R3, int C3>
87 class trace_inv_quad_form_ldlt_vari : public vari {
88  protected:
89  static inline void chainA(
90  double adj,
91  trace_inv_quad_form_ldlt_impl<double, R2, C2, T3, R3, C3> *impl) {}
92  static inline void chainB(
93  double adj,
94  trace_inv_quad_form_ldlt_impl<T2, R2, C2, double, R3, C3> *impl) {}
95 
96  static inline void chainA(
97  double adj,
98  trace_inv_quad_form_ldlt_impl<var, R2, C2, T3, R3, C3> *impl) {
99  Eigen::Matrix<double, R2, C2> aA;
100 
101  if (impl->Dtype_ != 2)
102  aA.noalias()
103  = -adj
104  * (impl->AinvB_ * impl->D_.transpose() * impl->AinvB_.transpose());
105  else
106  aA.noalias() = -adj * (impl->AinvB_ * impl->AinvB_.transpose());
107 
108  for (int j = 0; j < aA.cols(); j++)
109  for (int i = 0; i < aA.rows(); i++)
110  impl->ldlt_.alloc_->variA_(i, j)->adj_ += aA(i, j);
111  }
112  static inline void chainB(
113  double adj,
114  trace_inv_quad_form_ldlt_impl<T2, R2, C2, var, R3, C3> *impl) {
115  Eigen::Matrix<double, R3, C3> aB;
116 
117  if (impl->Dtype_ != 2)
118  aB.noalias() = adj * impl->AinvB_ * (impl->D_ + impl->D_.transpose());
119  else
120  aB.noalias() = 2 * adj * impl->AinvB_;
121 
122  for (int j = 0; j < aB.cols(); j++)
123  for (int i = 0; i < aB.rows(); i++)
124  impl->variB_(i, j)->adj_ += aB(i, j);
125  }
126 
127  public:
128  explicit trace_inv_quad_form_ldlt_vari(
129  trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *impl)
130  : vari(impl->value_), impl_(impl) {}
131 
132  virtual void chain() {
133  // F = trace(D * B' * inv(A) * B)
134  // aA = -aF * inv(A') * B * D' * B' * inv(A')
135  // aB = aF*(inv(A) * B * D + inv(A') * B * D')
136  // aD = aF*(B' * inv(A) * B)
137  chainA(adj_, impl_);
138 
139  chainB(adj_, impl_);
140 
141  if (impl_->Dtype_ == 1) {
142  for (int j = 0; j < impl_->variD_.cols(); j++)
143  for (int i = 0; i < impl_->variD_.rows(); i++)
144  impl_->variD_(i, j)->adj_ += adj_ * impl_->C_(i, j);
145  }
146  }
147 
148  trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *impl_;
149 };
150 
151 } // namespace
152 
153 /**
154  * Compute the trace of an inverse quadratic form, i.e. this computes
155  * trace(B^T A^-1 B)
156  * where the LDLT_factor of A is provided.
 *
 * The function first calls check_multiplicable on A and B, then builds a
 * trace_inv_quad_form_ldlt_impl from (A, B) and wraps it in a
 * trace_inv_quad_form_ldlt_vari; the returned var is constructed from that
 * vari.
 *
 * NOTE(review): in this listing the signature is incomplete -- the
 * enable_if_c condition, the return type, the function name and the
 * declaration of the parameter A are not visible between the
 * "enable_if_c<" line and the line declaring B. Confirm against the
 * original header before relying on the notes below.
 *
 * @tparam T2 scalar type of the factored matrix A
 * @tparam R2 compile-time rows of A
 * @tparam C2 compile-time columns of A
 * @tparam T3 scalar type of B
 * @tparam R3 compile-time rows of B
 * @tparam C3 compile-time columns of B
 * @param A factor of A; its declaration is not visible here, but the body
 *   forwards it to the impl constructor, which takes a
 *   const LDLT_factor<T2, R2, C2>&
 * @param B right-hand matrix of the quadratic form
 * @return var wrapping the newly created trace_inv_quad_form_ldlt_vari
 * @throw presumably whatever check_multiplicable throws when the sizes of
 *   A and B are incompatible -- confirm against its definition
157  **/
158 template <typename T2, int R2, int C2, typename T3, int R3, int C3>
159 inline typename boost::enable_if_c<
162  const Eigen::Matrix<T3, R3, C3> &B) {
163  check_multiplicable("trace_inv_quad_form_ldlt", "A", A, "B", B);
164 
// Two-argument impl constructor: no D matrix, so the impl stores
// trace(B' * inv(A) * B) as its value.
// NOTE(review): no matching delete appears in this function; presumably the
// chainable_alloc base of impl takes care of its lifetime -- confirm.
165  trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *impl_
166  = new trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3>(A, B);
167 
// The vari copies impl_->value_ as its value and keeps the impl pointer for
// its chain() method.
168  return var(new trace_inv_quad_form_ldlt_vari<T2, R2, C2, T3, R3, C3>(impl_));
169 }
170 
171 } // namespace math
172 } // namespace stan
173 #endif
const int Dtype_
Eigen::Matrix< double, C3, C3 > C_
Eigen::Matrix< vari *, R3, C3 > variB_
double value_
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > D_
const XML_Char int const XML_Char * value
Definition: expat.h:331
chain
Check that an output directory exists.
const double j
Definition: BetheBloch.cxx:29
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
static const double A
Definition: Units.h:82
LDLT_factor< T2, R2, C2 > ldlt_
Eigen::Matrix< vari *, Eigen::Dynamic, Eigen::Dynamic > variD_
Eigen::Matrix< double, R3, C3 > AinvB_
T trace(const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &m)
Definition: trace.hpp:19
boost::enable_if_c< !stan::is_var< T1 >::value &&!stan::is_var< T2 >::value, typename boost::math::tools::promote_args< T1, T2 >::type >::type trace_inv_quad_form_ldlt(const LDLT_factor< T1, R2, C2 > &A, const Eigen::Matrix< T2, R3, C3 > &B)
trace_inv_quad_form_ldlt_impl< T2, R2, C2, T3, R3, C3 > * impl_