base_family.hpp
Go to the documentation of this file.
1 #ifndef STAN_VARIATIONAL_BASE_FAMILY_HPP
2 #define STAN_VARIATIONAL_BASE_FAMILY_HPP
3 
5 #include <stan/math/prim/mat.hpp>
6 #include <algorithm>
7 #include <ostream>
8 
9 namespace stan {
10  namespace variational {
11 
12  class base_family {
13  public:
14  // Constructors
16 
17  // Operations
18  base_family square() const;
19  base_family sqrt() const;
20 
21  // Compound assignment operators
22  base_family operator=(const base_family& rhs);
25  base_family operator+=(double scalar);
26  base_family operator*=(double scalar);
27 
28  // Distribution-based operations
29  const Eigen::VectorXd& mean() const;
30  double entropy() const;
31  Eigen::VectorXd transform(const Eigen::VectorXd& eta) const;
32  template <class BaseRNG>
33  void sample(BaseRNG& rng, Eigen::VectorXd& eta) const;
34  template <class M, class BaseRNG>
35  void calc_grad(base_family& elbo_grad,
36  M& m,
37  Eigen::VectorXd& cont_params,
38  int n_monte_carlo_grad,
39  BaseRNG& rng,
40  callbacks::logger& logger)
41  const;
42 
43  protected:
44  void write_error_msg_(std::ostream* error_msgs,
45  const std::exception& e) const {
46  if (!error_msgs) {
47  return;
48  }
49 
50  *error_msgs
51  << std::endl
52  << "Informational Message: The current gradient evaluation "
53  << "of the ELBO is ignored because of the following issue:"
54  << std::endl
55  << e.what() << std::endl
56  << "If this warning occurs often then your model may be "
57  << "either severely ill-conditioned or misspecified."
58  << std::endl;
59  }
60  };
61 
62  // Arithmetic operators
65  base_family operator+(double scalar, base_family rhs);
66  base_family operator*(double scalar, base_family rhs);
67  } // variational
68 } // stan
69 #endif
base_family operator/(base_family lhs, const base_family &rhs)
base_family operator*(double scalar, base_family rhs)
base_family operator/=(const base_family &rhs)
::xsd::cxx::tree::exception< char > exception
Definition: Database.h:225
void calc_grad(base_family &elbo_grad, M &m, Eigen::VectorXd &cont_params, int n_monte_carlo_grad, BaseRNG &rng, callbacks::logger &logger) const
base_family sqrt() const
Eigen::VectorXd transform(const Eigen::VectorXd &eta) const
base_family operator+(base_family lhs, const base_family &rhs)
void write_error_msg_(std::ostream *error_msgs, const std::exception &e) const
Definition: base_family.hpp:44
base_family operator+=(const base_family &rhs)
base_family operator*=(double scalar)
base_family square() const
base_family operator=(const base_family &rhs)
Float_t e
Definition: plot.C:35
void sample(BaseRNG &rng, Eigen::VectorXd &eta) const
const Eigen::VectorXd & mean() const