base_xhmc.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_NUTS_BASE_XHMC_HPP
2 #define STAN_MCMC_HMC_NUTS_BASE_XHMC_HPP
3 
5 #include <boost/math/special_functions/fpclassify.hpp>
8 #include <algorithm>
9 #include <cmath>
10 #include <limits>
11 #include <string>
12 #include <vector>
13 
14 namespace stan {
15  namespace mcmc {
/**
 * Pool two weighted running averages into one.
 *
 * a1 and a2 are running averages of the form
 * \f$ a1 = ( \sum_{n \in N1} w_{n} f_{n} ) / ( \sum_{n \in N1} w_{n} ) \f$
 * \f$ a2 = ( \sum_{n \in N2} w_{n} f_{n} ) / ( \sum_{n \in N2} w_{n} ) \f$
 * and the weights are the respective normalizing constants
 * \f$ w1 = \sum_{n \in N1} w_{n} \f$
 * \f$ w2 = \sum_{n \in N2} w_{n}. \f$
 *
 * This function returns the pooled average
 * \f$ \mathrm{sum\_a} = ( \sum_{n \in N1 \cup N2} w_{n} f_{n} )
 *                      / ( \sum_{n \in N1 \cup N2} w_{n} ) \f$
 * and the log of the pooled weight
 * \f$ \mathrm{log\_sum\_w} = \log(w1 + w2). \f$
 *
 * Weights are handled on the log scale and the larger weight is factored
 * out, so only exp(non-positive) is ever evaluated. A summand with zero
 * weight (log weight of negative infinity) contributes nothing and the
 * other summand is returned unchanged; if both weights are zero the first
 * summand is returned. This avoids the NaN that exp(-inf - (-inf)) or
 * 0 * inf would otherwise produce.
 *
 * Inputs are taken by value, so the outputs may alias them, e.g.
 * stable_sum(a, log_w, b, log_v, a, log_w).
 *
 * Declared inline because it is defined in a header.
 *
 * @param a1 First running average
 * @param log_w1 Log of first summed weight
 * @param a2 Second running average
 * @param log_w2 Log of second summed weight
 * @param[out] sum_a Weighted average of the two running averages
 * @param[out] log_sum_w Log of the summed input weights
 */
inline void stable_sum(double a1, double log_w1, double a2, double log_w2,
                       double& sum_a, double& log_sum_w) {
  const double neg_inf = -std::numeric_limits<double>::infinity();

  // Zero-weight summands contribute nothing to the pool.
  if (log_w2 == neg_inf) {
    sum_a = a1;
    log_sum_w = log_w1;
    return;
  }
  if (log_w1 == neg_inf) {
    sum_a = a2;
    log_sum_w = log_w2;
    return;
  }

  // Factor out the larger weight so that e = w_small / w_large <= 1.
  if (log_w2 > log_w1) {
    const double e = std::exp(log_w1 - log_w2);
    sum_a = (e * a1 + a2) / (1 + e);
    log_sum_w = log_w2 + std::log1p(e);
  } else {
    const double e = std::exp(log_w2 - log_w1);
    sum_a = (a1 + e * a2) / (1 + e);
    log_sum_w = log_w1 + std::log1p(e);
  }
}
51 
52  /**
53  * Exhaustive Hamiltonian Monte Carlo (XHMC) with multinomial sampling.
54  * See http://arxiv.org/abs/1601.00225.
55  */
56  template <class Model, template<class, class> class Hamiltonian,
57  template<class> class Integrator, class BaseRNG>
58  class base_xhmc : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
59  public:
60  base_xhmc(const Model& model, BaseRNG& rng)
61  : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
62  depth_(0), max_depth_(5), max_deltaH_(1000), x_delta_(0.1),
63  n_leapfrog_(0), divergent_(0), energy_(0) {
64  }
65 
67 
68  void set_max_depth(int d) {
69  if (d > 0)
70  max_depth_ = d;
71  }
72 
73  void set_max_deltaH(double d) {
74  max_deltaH_ = d;
75  }
76 
77  void set_x_delta(double d) {
78  if (d > 0)
79  x_delta_ = d;
80  }
81 
82  int get_max_depth() { return this->max_depth_; }
83  double get_max_deltaH() { return this->max_deltaH_; }
84  double get_x_delta() { return this->x_delta_; }
85 
86  sample
87  transition(sample& init_sample, callbacks::logger& logger) {
88  // Initialize the algorithm
89  this->sample_stepsize();
90 
91  this->seed(init_sample.cont_params());
92 
93  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
94  this->hamiltonian_.init(this->z_, logger);
95 
96  ps_point z_plus(this->z_);
97  ps_point z_minus(z_plus);
98 
99  ps_point z_sample(z_plus);
100  ps_point z_propose(z_plus);
101 
102  double ave = this->hamiltonian_.dG_dt(this->z_, logger);
103  double log_sum_weight = 0; // log(exp(H0 - H0))
104 
105  double H0 = this->hamiltonian_.H(this->z_);
106  int n_leapfrog = 0;
107  double sum_metro_prob = 1; // exp(H0 - H0)
108 
109  // Build a trajectory until the NUTS criterion is no longer satisfied
110  this->depth_ = 0;
111  this->divergent_ = 0;
112 
113  while (this->depth_ < this->max_depth_) {
114  // Build a new subtree in a random direction
115  bool valid_subtree = false;
116  double ave_subtree = 0;
117  double log_sum_weight_subtree
118  = -std::numeric_limits<double>::infinity();
119 
120  if (this->rand_uniform_() > 0.5) {
121  this->z_.ps_point::operator=(z_plus);
122  valid_subtree
123  = build_tree(this->depth_, z_propose,
124  ave_subtree, log_sum_weight_subtree,
125  H0, 1, n_leapfrog, sum_metro_prob,
126  logger);
127  z_plus.ps_point::operator=(this->z_);
128  } else {
129  this->z_.ps_point::operator=(z_minus);
130  valid_subtree
131  = build_tree(this->depth_, z_propose,
132  ave_subtree, log_sum_weight_subtree,
133  H0, -1, n_leapfrog, sum_metro_prob,
134  logger);
135  z_minus.ps_point::operator=(this->z_);
136  }
137 
138  if (!valid_subtree) break;
139  stable_sum(ave, log_sum_weight,
140  ave_subtree, log_sum_weight_subtree,
141  ave, log_sum_weight);
142 
143  // Sample from an accepted subtree
144  ++(this->depth_);
145 
146  double accept_prob
147  = std::exp(log_sum_weight_subtree - log_sum_weight);
148  if (this->rand_uniform_() < accept_prob)
149  z_sample = z_propose;
150 
151  // Break if exhaustion criterion is satisfied
152  if (std::fabs(ave) < x_delta_)
153  break;
154  }
155 
156  this->n_leapfrog_ = n_leapfrog;
157 
158  // Compute average acceptance probabilty across entire trajectory,
159  // even over subtrees that may have been rejected
160  double accept_prob
161  = sum_metro_prob / static_cast<double>(n_leapfrog + 1);
162 
163  this->z_.ps_point::operator=(z_sample);
164  this->energy_ = this->hamiltonian_.H(this->z_);
165  return sample(this->z_.q, -this->z_.V, accept_prob);
166  }
167 
168  void get_sampler_param_names(std::vector<std::string>& names) {
169  names.push_back("stepsize__");
170  names.push_back("treedepth__");
171  names.push_back("n_leapfrog__");
172  names.push_back("divergent__");
173  names.push_back("energy__");
174  }
175 
176  void get_sampler_params(std::vector<double>& values) {
177  values.push_back(this->epsilon_);
178  values.push_back(this->depth_);
179  values.push_back(this->n_leapfrog_);
180  values.push_back(this->divergent_);
181  values.push_back(this->energy_);
182  }
183 
184  /**
185  * Recursively build a new subtree to completion or until
186  * the subtree becomes invalid. Returns validity of the
187  * resulting subtree.
188  *
189  * @param depth Depth of the desired subtree
190  * @param z_propose State proposed from subtree
191  * @param ave Weighted average of dG/dt across trajectory
192  * @param log_sum_weight Log of summed weights across trajectory
193  * @param H0 Hamiltonian of initial state
194  * @param sign Direction in time to built subtree
195  * @param n_leapfrog Summed number of leapfrog evaluations
196  * @param sum_metro_prob Summed Metropolis probabilities across trajectory
197  * @param logger Logger for messages
198  */
199  int build_tree(int depth, ps_point& z_propose,
200  double& ave, double& log_sum_weight,
201  double H0, double sign, int& n_leapfrog,
202  double& sum_metro_prob,
203  callbacks::logger& logger) {
204  // Base case
205  if (depth == 0) {
206  this->integrator_.evolve(this->z_, this->hamiltonian_,
207  sign * this->epsilon_,
208  logger);
209  ++n_leapfrog;
210 
211  double h = this->hamiltonian_.H(this->z_);
212  if (boost::math::isnan(h))
213  h = std::numeric_limits<double>::infinity();
214 
215  if ((h - H0) > this->max_deltaH_) this->divergent_ = true;
216 
217  double dG_dt = this->hamiltonian_.dG_dt(this->z_, logger);
218 
219  stable_sum(ave, log_sum_weight,
220  dG_dt, H0 - h,
221  ave, log_sum_weight);
222 
223  if (H0 - h > 0)
224  sum_metro_prob += 1;
225  else
226  sum_metro_prob += std::exp(H0 - h);
227 
228  z_propose = this->z_;
229 
230  return !this->divergent_;
231  }
232  // General recursion
233 
234  // Build the left subtree
235  double ave_left = 0;
236  double log_sum_weight_left = -std::numeric_limits<double>::infinity();
237 
238  bool valid_left
239  = build_tree(depth - 1, z_propose,
240  ave_left, log_sum_weight_left,
241  H0, sign, n_leapfrog, sum_metro_prob,
242  logger);
243 
244  if (!valid_left) return false;
245  stable_sum(ave, log_sum_weight,
246  ave_left, log_sum_weight_left,
247  ave, log_sum_weight);
248 
249  // Build the right subtree
250  ps_point z_propose_right(this->z_);
251  double ave_right = 0;
252  double log_sum_weight_right = -std::numeric_limits<double>::infinity();
253 
254  bool valid_right
255  = build_tree(depth - 1, z_propose_right,
256  ave_right, log_sum_weight_right,
257  H0, sign, n_leapfrog, sum_metro_prob,
258  logger);
259 
260  if (!valid_right) return false;
261  stable_sum(ave, log_sum_weight,
262  ave_right, log_sum_weight_right,
263  ave, log_sum_weight);
264 
265  // Multinomial sample from right subtree
266  double ave_subtree;
267  double log_sum_weight_subtree;
268  stable_sum(ave_left, log_sum_weight_left,
269  ave_right, log_sum_weight_right,
270  ave_subtree, log_sum_weight_subtree);
271 
272  double accept_prob
273  = std::exp(log_sum_weight_right - log_sum_weight_subtree);
274  if (this->rand_uniform_() < accept_prob)
275  z_propose = z_propose_right;
276 
277  return std::fabs(ave_subtree) >= x_delta_;
278  }
279 
280  int depth_;
282  double max_deltaH_;
283  double x_delta_;
284 
287  double energy_;
288  };
289 
290  } // mcmc
291 } // stan
292 #endif
sample transition(sample &init_sample, callbacks::logger &logger)
Definition: base_xhmc.hpp:87
Hamiltonian< Model, BaseRNG >::PointType z_
Definition: base_hmc.hpp:182
base_xhmc(const Model &model, BaseRNG &rng)
Definition: base_xhmc.hpp:60
TH1F * a2
Definition: f2_nu.C:545
fvar< T > fabs(const fvar< T > &x)
Definition: fabs.hpp:15
int build_tree(int depth, ps_point &z_propose, double &ave, double &log_sum_weight, double H0, double sign, int &n_leapfrog, double &sum_metro_prob, callbacks::logger &logger)
Definition: base_xhmc.hpp:199
void set_max_depth(int d)
Definition: base_xhmc.hpp:68
rosenbrock_model_namespace::rosenbrock_model Model
bool isnan(const stan::math::var &v)
Definition: boost_isnan.hpp:20
double cont_params(int k) const
Definition: sample.hpp:24
TH1F * a1
Definition: f2_nu.C:476
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:10
Float_t d
Definition: plot.C:236
void stable_sum(double a1, double log_w1, double a2, double log_w2, double &sum_a, double &log_sum_w)
Definition: base_xhmc.hpp:39
void set_x_delta(double d)
Definition: base_xhmc.hpp:77
void get_sampler_params(std::vector< double > &values)
Definition: base_xhmc.hpp:176
void seed(const Eigen::VectorXd &q)
Definition: base_hmc.hpp:73
void get_sampler_param_names(std::vector< std::string > &names)
Definition: base_xhmc.hpp:168
boost::uniform_01< BaseRNG & > rand_uniform_
Definition: base_hmc.hpp:189
void set_max_deltaH(double d)
Definition: base_xhmc.hpp:73
Float_t e
Definition: plot.C:35
Hamiltonian< Model, BaseRNG > hamiltonian_
Definition: base_hmc.hpp:184
const XML_Char XML_Content * model
Definition: expat.h:151
Integrator< Hamiltonian< Model, BaseRNG > > integrator_
Definition: base_hmc.hpp:183
def sign(x)
Definition: canMan.py:197