base_nuts_classic.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_NUTS_CLASSIC_BASE_NUTS_CLASSIC_HPP
2 #define STAN_MCMC_HMC_NUTS_CLASSIC_BASE_NUTS_CLASSIC_HPP
3
5 #include <boost/math/special_functions/fpclassify.hpp>
8 #include <algorithm>
9 #include <cmath>
10 #include <limits>
11 #include <string>
12 #include <vector>
13
14 namespace stan {
15  namespace mcmc {
16
// Bookkeeping shared by every recursive call of
// base_nuts_classic::build_tree() during a single transition.
struct nuts_util {
  // Constants through each recursion
  double log_u;  // log of the slice variable, drawn once per transition
  double H0;     // Hamiltonian at the starting point of the trajectory
  int sign;      // integration direction: +1 forward, -1 backward in time

  // Aggregators through each recursion
  int n_tree;       // number of integrator steps taken so far
  double sum_prob;  // running sum of min(1, exp(H0 - H)) over those steps
  bool criterion;   // becomes false on a divergence or a U-turn

  // Initialize every field, not just `criterion`, so a freshly constructed
  // nuts_util never exposes indeterminate values.
  nuts_util()
    : log_u(0), H0(0), sign(0), n_tree(0), sum_prob(0), criterion(false) { }
};
31
32  // The No-U-Turn Sampler (NUTS) with the
33  // original slice sampler implementation
34  template <class Model, template<class, class> class Hamiltonian,
35  template<class> class Integrator, class BaseRNG>
37  public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
38  public:
39  base_nuts_classic(const Model& model, BaseRNG& rng):
40  base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
41  depth_(0), max_depth_(5), max_delta_(1000),
42  n_leapfrog_(0), divergent_(0), energy_(0) {
43  }
44
46
47  void set_max_depth(int d) {
48  if (d > 0)
49  max_depth_ = d;
50  }
51
52  void set_max_delta(double d) {
53  max_delta_ = d;
54  }
55
56  int get_max_depth() { return this->max_depth_; }
57  double get_max_delta() { return this->max_delta_; }
58
59  sample
60  transition(sample& init_sample, callbacks::logger& logger) {
61  // Initialize the algorithm
62  this->sample_stepsize();
63
65
66  this->seed(init_sample.cont_params());
67
68  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
69  this->hamiltonian_.init(this->z_, logger);
70
71  ps_point z_plus(this->z_);
72  ps_point z_minus(z_plus);
73
74  ps_point z_sample(z_plus);
75  ps_point z_propose(z_plus);
76
77  int n_cont = init_sample.cont_params().size();
78
79  Eigen::VectorXd rho_init = this->z_.p;
80  Eigen::VectorXd rho_plus(n_cont); rho_plus.setZero();
81  Eigen::VectorXd rho_minus(n_cont); rho_minus.setZero();
82
83  util.H0 = this->hamiltonian_.H(this->z_);
84
85  // Sample the slice variable
86  util.log_u = std::log(this->rand_uniform_());
87
88  // Build a balanced binary tree until the NUTS criterion fails
89  util.criterion = true;
90  int n_valid = 0;
91
92  this->depth_ = 0;
93  this->divergent_ = 0;
94
95  util.n_tree = 0;
96  util.sum_prob = 0;
97
98  while (util.criterion && (this->depth_ <= this->max_depth_)) {
99  // Randomly sample a direction in time
100  ps_point* z = 0;
101  Eigen::VectorXd* rho = 0;
102
103  if (this->rand_uniform_() > 0.5) {
104  z = &z_plus;
105  rho = &rho_plus;
106  util.sign = 1;
107  } else {
108  z = &z_minus;
109  rho = &rho_minus;
110  util.sign = -1;
111  }
112
113  // And build a new subtree in that direction
114  this->z_.ps_point::operator=(*z);
115
116  int n_valid_subtree = build_tree(depth_, *rho, 0, z_propose, util,
117  logger);
118  ++(this->depth_);
119
120  *z = this->z_;
121
122  // Metropolis-Hastings sample the fresh subtree
123  if (!util.criterion)
124  break;
125
126  double subtree_prob = 0;
127
128  if (n_valid) {
129  subtree_prob = static_cast<double>(n_valid_subtree) /
130  static_cast<double>(n_valid);
131  } else {
132  subtree_prob = n_valid_subtree ? 1 : 0;
133  }
134
135  if (this->rand_uniform_() < subtree_prob)
136  z_sample = z_propose;
137
138  n_valid += n_valid_subtree;
139
140  // Check validity of completed tree
141  this->z_.ps_point::operator=(z_plus);
142  Eigen::VectorXd delta_rho = rho_minus + rho_init + rho_plus;
143
144  util.criterion = compute_criterion(z_minus, this->z_, delta_rho);
145  }
146
147  this->n_leapfrog_ = util.n_tree;
148
149  double accept_prob = util.sum_prob / static_cast<double>(util.n_tree);
150
151  this->z_.ps_point::operator=(z_sample);
152  this->energy_ = this->hamiltonian_.H(this->z_);
153  return sample(this->z_.q, - this->z_.V, accept_prob);
154  }
155
156  void get_sampler_param_names(std::vector<std::string>& names) {
157  names.push_back("stepsize__");
158  names.push_back("treedepth__");
159  names.push_back("n_leapfrog__");
160  names.push_back("divergent__");
161  names.push_back("energy__");
162  }
163
164  void get_sampler_params(std::vector<double>& values) {
165  values.push_back(this->epsilon_);
166  values.push_back(this->depth_);
167  values.push_back(this->n_leapfrog_);
168  values.push_back(this->divergent_);
169  values.push_back(this->energy_);
170  }
171
172  virtual bool compute_criterion(ps_point& start,
173  typename Hamiltonian<Model, BaseRNG>
174  ::PointType& finish,
175  Eigen::VectorXd& rho) = 0;
176
177  // Returns number of valid points in the completed subtree
178  int build_tree(int depth, Eigen::VectorXd& rho,
179  ps_point* z_init_parent, ps_point& z_propose,
180  nuts_util& util,
181  callbacks::logger& logger) {
182  // Base case
183  if (depth == 0) {
184  this->integrator_.evolve(this->z_, this->hamiltonian_,
185  util.sign * this->epsilon_,
186  logger);
187  rho += this->z_.p;
188
189  if (z_init_parent) *z_init_parent = this->z_;
190  z_propose = this->z_;
191
192  double h = this->hamiltonian_.H(this->z_);
193  if (boost::math::isnan(h))
194  h = std::numeric_limits<double>::infinity();
195
196  util.criterion = util.log_u + (h - util.H0) < this->max_delta_;
197  if (!util.criterion) ++(this->divergent_);
198
199  util.sum_prob += std::min(1.0, std::exp(util.H0 - h));
200  util.n_tree += 1;
201
202  return (util.log_u + (h - util.H0) < 0);
203
204  } else {
205  // General recursion
206  Eigen::VectorXd left_subtree_rho(rho.size());
207  left_subtree_rho.setZero();
208  ps_point z_init(this->z_);
209
210  int n1 = build_tree(depth - 1, left_subtree_rho, &z_init,
211  z_propose, util,
212  logger);
213
214  if (z_init_parent) *z_init_parent = z_init;
215
216  if (!util.criterion) return 0;
217
218  Eigen::VectorXd right_subtree_rho(rho.size());
219  right_subtree_rho.setZero();
220  ps_point z_propose_right(z_init);
221
222  int n2 = build_tree(depth - 1, right_subtree_rho, 0,
223  z_propose_right, util,
224  logger);
225
226  double accept_prob = static_cast<double>(n2) /
227  static_cast<double>(n1 + n2);
228
229  if ( util.criterion && (this->rand_uniform_() < accept_prob) )
230  z_propose = z_propose_right;
231
232  Eigen::VectorXd& subtree_rho = left_subtree_rho;
233  subtree_rho += right_subtree_rho;
234
235  rho += subtree_rho;
236
237  util.criterion &= compute_criterion(z_init, this->z_, subtree_rho);
238
239  return n1 + n2;
240  }
241  }
242
243  int depth_;
245  double max_delta_;
246
249  double energy_;
250  };
251
252  } // mcmc
253 } // stan
254 #endif
Filter events based on their run/event numbers.
sample transition(sample &init_sample, callbacks::logger &logger)
rosenbrock_model_namespace::rosenbrock_model Model
bool isnan(const stan::math::var &v)
Definition: boost_isnan.hpp:20
double cont_params(int k) const
Definition: sample.hpp:24
unsigned int seed
Definition: runWimpSim.h:102
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:10
Float_t d
Definition: plot.C:236
base_nuts_classic(const Model &model, BaseRNG &rng)
z
Definition: test.py:28
void get_sampler_param_names(std::vector< std::string > &names)
T min(const caf::Proxy< T > &a, T b)
int build_tree(int depth, Eigen::VectorXd &rho, ps_point *z_init_parent, ps_point &z_propose, nuts_util &util, callbacks::logger &logger)
const XML_Char XML_Content * model
Definition: expat.h:151
void get_sampler_params(std::vector< double > &values)