base_nuts.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
2 #define STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
3 
5 #include <boost/math/special_functions/fpclassify.hpp>
9 #include <algorithm>
10 #include <cmath>
11 #include <limits>
12 #include <string>
13 #include <vector>
14 
15 namespace stan {
16  namespace mcmc {
17  /**
18  * The No-U-Turn sampler (NUTS) with multinomial sampling
19  */
20  template <class Model, template<class, class> class Hamiltonian,
21  template<class> class Integrator, class BaseRNG>
22  class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
23  public:
24  base_nuts(const Model& model, BaseRNG& rng)
25  : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
26  depth_(0), max_depth_(5), max_deltaH_(1000),
27  n_leapfrog_(0), divergent_(false), energy_(0) {
28  }
29 
30  /**
31  * specialized constructor for specified diag mass matrix
32  */
33  base_nuts(const Model& model, BaseRNG& rng,
34  Eigen::VectorXd& inv_e_metric)
35  : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng,
36  inv_e_metric),
37  depth_(0), max_depth_(5), max_deltaH_(1000),
38  n_leapfrog_(0), divergent_(false), energy_(0) {
39  }
40 
41  /**
42  * specialized constructor for specified dense mass matrix
43  */
44  base_nuts(const Model& model, BaseRNG& rng,
45  Eigen::MatrixXd& inv_e_metric)
46  : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng,
47  inv_e_metric),
48  depth_(0), max_depth_(5), max_deltaH_(1000),
49  n_leapfrog_(0), divergent_(false), energy_(0) {
50  }
51 
53 
54  void set_metric(const Eigen::MatrixXd& inv_e_metric) {
55  this->z_.set_metric(inv_e_metric);
56  }
57 
58  void set_metric(const Eigen::VectorXd& inv_e_metric) {
59  this->z_.set_metric(inv_e_metric);
60  }
61 
62  void set_max_depth(int d) {
63  if (d > 0)
64  max_depth_ = d;
65  }
66 
67  void set_max_delta(double d) {
68  max_deltaH_ = d;
69  }
70 
71  int get_max_depth() { return this->max_depth_; }
72  double get_max_delta() { return this->max_deltaH_; }
73 
74  sample
75  transition(sample& init_sample, callbacks::logger& logger) {
76  // Initialize the algorithm
77  this->sample_stepsize();
78 
79  this->seed(init_sample.cont_params());
80 
81  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
82  this->hamiltonian_.init(this->z_, logger);
83 
84  ps_point z_plus(this->z_);
85  ps_point z_minus(z_plus);
86 
87  ps_point z_sample(z_plus);
88  ps_point z_propose(z_plus);
89 
90  Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_);
91  Eigen::VectorXd p_sharp_dummy = p_sharp_plus;
92  Eigen::VectorXd p_sharp_minus = p_sharp_plus;
93  Eigen::VectorXd rho = this->z_.p;
94 
95  double log_sum_weight = 0; // log(exp(H0 - H0))
96  double H0 = this->hamiltonian_.H(this->z_);
97  int n_leapfrog = 0;
98  double sum_metro_prob = 0;
99 
100  // Build a trajectory until the NUTS criterion is no longer satisfied
101  this->depth_ = 0;
102  this->divergent_ = false;
103 
104  while (this->depth_ < this->max_depth_) {
105  // Build a new subtree in a random direction
106  Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size());
107  bool valid_subtree = false;
108  double log_sum_weight_subtree
109  = -std::numeric_limits<double>::infinity();
110 
111  if (this->rand_uniform_() > 0.5) {
112  this->z_.ps_point::operator=(z_plus);
113  valid_subtree
114  = build_tree(this->depth_, z_propose,
115  p_sharp_dummy, p_sharp_plus, rho_subtree,
116  H0, 1, n_leapfrog,
117  log_sum_weight_subtree, sum_metro_prob,
118  logger);
119  z_plus.ps_point::operator=(this->z_);
120  } else {
121  this->z_.ps_point::operator=(z_minus);
122  valid_subtree
123  = build_tree(this->depth_, z_propose,
124  p_sharp_dummy, p_sharp_minus, rho_subtree,
125  H0, -1, n_leapfrog,
126  log_sum_weight_subtree, sum_metro_prob,
127  logger);
128  z_minus.ps_point::operator=(this->z_);
129  }
130 
131  if (!valid_subtree) break;
132 
133  // Sample from an accepted subtree
134  ++(this->depth_);
135 
136  if (log_sum_weight_subtree > log_sum_weight) {
137  z_sample = z_propose;
138  } else {
139  double accept_prob
140  = std::exp(log_sum_weight_subtree - log_sum_weight);
141  if (this->rand_uniform_() < accept_prob)
142  z_sample = z_propose;
143  }
144 
145  log_sum_weight
146  = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
147 
148  // Break when NUTS criterion is no longer satisfied
149  rho += rho_subtree;
150  if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho))
151  break;
152  }
153 
154  this->n_leapfrog_ = n_leapfrog;
155 
156  // Compute average acceptance probabilty across entire trajectory,
157  // even over subtrees that may have been rejected
158  double accept_prob
159  = sum_metro_prob / static_cast<double>(n_leapfrog);
160 
161  this->z_.ps_point::operator=(z_sample);
162  this->energy_ = this->hamiltonian_.H(this->z_);
163  return sample(this->z_.q, -this->z_.V, accept_prob);
164  }
165 
166  void get_sampler_param_names(std::vector<std::string>& names) {
167  names.push_back("stepsize__");
168  names.push_back("treedepth__");
169  names.push_back("n_leapfrog__");
170  names.push_back("divergent__");
171  names.push_back("energy__");
172  }
173 
174  void get_sampler_params(std::vector<double>& values) {
175  values.push_back(this->epsilon_);
176  values.push_back(this->depth_);
177  values.push_back(this->n_leapfrog_);
178  values.push_back(this->divergent_);
179  values.push_back(this->energy_);
180  }
181 
182  virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus,
183  Eigen::VectorXd& p_sharp_plus,
184  Eigen::VectorXd& rho) {
185  return p_sharp_plus.dot(rho) > 0
186  && p_sharp_minus.dot(rho) > 0;
187  }
188 
189  /**
190  * Recursively build a new subtree to completion or until
191  * the subtree becomes invalid. Returns validity of the
192  * resulting subtree.
193  *
194  * @param depth Depth of the desired subtree
195  * @param z_propose State proposed from subtree
196  * @param p_sharp_left p_sharp from left boundary of returned tree
197  * @param p_sharp_right p_sharp from the right boundary of returned tree
198  * @param rho Summed momentum across trajectory
199  * @param H0 Hamiltonian of initial state
200  * @param sign Direction in time to built subtree
201  * @param n_leapfrog Summed number of leapfrog evaluations
202  * @param log_sum_weight Log of summed weights across trajectory
203  * @param sum_metro_prob Summed Metropolis probabilities across trajectory
204  * @param logger Logger for messages
205  */
206  bool build_tree(int depth, ps_point& z_propose,
207  Eigen::VectorXd& p_sharp_left,
208  Eigen::VectorXd& p_sharp_right,
209  Eigen::VectorXd& rho,
210  double H0, double sign, int& n_leapfrog,
211  double& log_sum_weight, double& sum_metro_prob,
212  callbacks::logger& logger) {
213  // Base case
214  if (depth == 0) {
215  this->integrator_.evolve(this->z_, this->hamiltonian_,
216  sign * this->epsilon_,
217  logger);
218  ++n_leapfrog;
219 
220  double h = this->hamiltonian_.H(this->z_);
221  if (boost::math::isnan(h))
222  h = std::numeric_limits<double>::infinity();
223 
224  if ((h - H0) > this->max_deltaH_) this->divergent_ = true;
225 
226  log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h);
227 
228  if (H0 - h > 0)
229  sum_metro_prob += 1;
230  else
231  sum_metro_prob += std::exp(H0 - h);
232 
233  z_propose = this->z_;
234  rho += this->z_.p;
235 
236  p_sharp_left = this->hamiltonian_.dtau_dp(this->z_);
237  p_sharp_right = p_sharp_left;
238 
239  return !this->divergent_;
240  }
241  // General recursion
242  Eigen::VectorXd p_sharp_dummy(this->z_.p.size());
243 
244  // Build the left subtree
245  double log_sum_weight_left = -std::numeric_limits<double>::infinity();
246  Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size());
247 
248  bool valid_left
249  = build_tree(depth - 1, z_propose,
250  p_sharp_left, p_sharp_dummy, rho_left,
251  H0, sign, n_leapfrog,
252  log_sum_weight_left, sum_metro_prob,
253  logger);
254 
255  if (!valid_left) return false;
256 
257  // Build the right subtree
258  ps_point z_propose_right(this->z_);
259 
260  double log_sum_weight_right = -std::numeric_limits<double>::infinity();
261  Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size());
262 
263  bool valid_right
264  = build_tree(depth - 1, z_propose_right,
265  p_sharp_dummy, p_sharp_right, rho_right,
266  H0, sign, n_leapfrog,
267  log_sum_weight_right, sum_metro_prob,
268  logger);
269 
270  if (!valid_right) return false;
271 
272  // Multinomial sample from right subtree
273  double log_sum_weight_subtree
274  = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right);
275  log_sum_weight
276  = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
277 
278  if (log_sum_weight_right > log_sum_weight_subtree) {
279  z_propose = z_propose_right;
280  } else {
281  double accept_prob
282  = std::exp(log_sum_weight_right - log_sum_weight_subtree);
283  if (this->rand_uniform_() < accept_prob)
284  z_propose = z_propose_right;
285  }
286 
287  Eigen::VectorXd rho_subtree = rho_left + rho_right;
288  rho += rho_subtree;
289 
290  return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree);
291  }
292 
293  int depth_;
295  double max_deltaH_;
296 
299  double energy_;
300  };
301 
302  } // mcmc
303 } // stan
304 #endif
base_nuts(const Model &model, BaseRNG &rng)
Definition: base_nuts.hpp:24
Hamiltonian< Model, BaseRNG >::PointType z_
Definition: base_hmc.hpp:182
void get_sampler_params(std::vector< double > &values)
Definition: base_nuts.hpp:174
void set_max_delta(double d)
Definition: base_nuts.hpp:67
rosenbrock_model_namespace::rosenbrock_model Model
fvar< T > log_sum_exp(const std::vector< fvar< T > > &v)
Definition: log_sum_exp.hpp:12
bool isnan(const stan::math::var &v)
Definition: boost_isnan.hpp:20
double cont_params(int k) const
Definition: sample.hpp:24
bool build_tree(int depth, ps_point &z_propose, Eigen::VectorXd &p_sharp_left, Eigen::VectorXd &p_sharp_right, Eigen::VectorXd &rho, double H0, double sign, int &n_leapfrog, double &log_sum_weight, double &sum_metro_prob, callbacks::logger &logger)
Definition: base_nuts.hpp:206
void set_metric(const Eigen::MatrixXd &inv_e_metric)
Definition: base_nuts.hpp:54
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:10
Float_t d
Definition: plot.C:236
void set_metric(const Eigen::VectorXd &inv_e_metric)
Definition: base_nuts.hpp:58
sample transition(sample &init_sample, callbacks::logger &logger)
Definition: base_nuts.hpp:75
void seed(const Eigen::VectorXd &q)
Definition: base_hmc.hpp:73
virtual bool compute_criterion(Eigen::VectorXd &p_sharp_minus, Eigen::VectorXd &p_sharp_plus, Eigen::VectorXd &rho)
Definition: base_nuts.hpp:182
base_nuts(const Model &model, BaseRNG &rng, Eigen::MatrixXd &inv_e_metric)
Definition: base_nuts.hpp:44
boost::uniform_01< BaseRNG & > rand_uniform_
Definition: base_hmc.hpp:189
void set_max_depth(int d)
Definition: base_nuts.hpp:62
void Zero()
void get_sampler_param_names(std::vector< std::string > &names)
Definition: base_nuts.hpp:166
Hamiltonian< Model, BaseRNG > hamiltonian_
Definition: base_hmc.hpp:184
const XML_Char XML_Content * model
Definition: expat.h:151
Integrator< Hamiltonian< Model, BaseRNG > > integrator_
Definition: base_hmc.hpp:183
base_nuts(const Model &model, BaseRNG &rng, Eigen::VectorXd &inv_e_metric)
Definition: base_nuts.hpp:33
def sign(x)
Definition: canMan.py:197