hmc_nuts_dense_e.hpp
Go to the documentation of this file.
1 #ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_DENSE_E_HPP
2 #define STAN_SERVICES_SAMPLE_HMC_NUTS_DENSE_E_HPP
3 
8 #include <stan/math/prim/mat.hpp>
15 #include <vector>
16 
17 namespace stan {
18  namespace services {
19  namespace sample {
20 
21  /**
22  * Runs HMC with NUTS without adaptation using dense Euclidean metric
23  * with a pre-specified Euclidean metric.
24  *
25  * @tparam Model Model class
26  * @param[in] model Input model to test (with data already instantiated)
27  * @param[in] init var context for initialization
28  * @param[in] init_inv_metric var context exposing an initial dense
29  inverse Euclidean metric (must be positive definite)
30  * @param[in] random_seed random seed for the random number generator
31  * @param[in] chain chain id to advance the pseudo random number generator
32  * @param[in] init_radius radius to initialize
33  * @param[in] num_warmup Number of warmup samples
34  * @param[in] num_samples Number of samples
35  * @param[in] num_thin Number to thin the samples
36  * @param[in] save_warmup Indicates whether to save the warmup iterations
37  * @param[in] refresh Controls the output
38  * @param[in] stepsize initial stepsize for discrete evolution
39  * @param[in] stepsize_jitter uniform random jitter of stepsize
40  * @param[in] max_depth Maximum tree depth
41  * @param[in,out] interrupt Callback for interrupts
42  * @param[in,out] logger Logger for messages
43  * @param[in,out] init_writer Writer callback for unconstrained inits
44  * @param[in,out] sample_writer Writer for draws
45  * @param[in,out] diagnostic_writer Writer for diagnostic information
46  * @return error_codes::OK if successful
47  */
48  template <class Model>
50  stan::io::var_context& init_inv_metric,
51  unsigned int random_seed, unsigned int chain,
52  double init_radius, int num_warmup, int num_samples,
53  int num_thin, bool save_warmup, int refresh,
54  double stepsize, double stepsize_jitter,
55  int max_depth,
56  callbacks::interrupt& interrupt,
57  callbacks::logger& logger,
58  callbacks::writer& init_writer,
59  callbacks::writer& sample_writer,
60  callbacks::writer& diagnostic_writer) {
61  boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
62 
63  std::vector<int> disc_vector;
64  std::vector<double> cont_vector
65  = util::initialize(model, init, rng, init_radius, true,
66  logger, init_writer);
67 
68  Eigen::MatrixXd inv_metric;
69  try {
70  inv_metric =
71  util::read_dense_inv_metric(init_inv_metric, model.num_params_r(),
72  logger);
73  util::validate_dense_inv_metric(inv_metric, logger);
74  } catch (const std::domain_error& e) {
75  return error_codes::CONFIG;
76  }
77 
79  sampler(model, rng);
80 
81  sampler.set_metric(inv_metric);
82 
83  sampler.set_nominal_stepsize(stepsize);
84  sampler.set_stepsize_jitter(stepsize_jitter);
85  sampler.set_max_depth(max_depth);
86 
87  util::run_sampler(sampler, model, cont_vector, num_warmup, num_samples,
88  num_thin, refresh, save_warmup, rng, interrupt,
89  logger,
90  sample_writer, diagnostic_writer);
91  return error_codes::OK;
92  }
93 
94  /**
95  * Runs HMC with NUTS without adaptation using dense Euclidean metric,
96  * with identity matrix as initial inv_metric.
97  *
98  * @tparam Model Model class
99  * @param[in] model Input model to test (with data already instantiated)
100  * @param[in] init var context for initialization
101  * @param[in] random_seed random seed for the random number generator
102  * @param[in] chain chain id to advance the pseudo random number generator
103  * @param[in] init_radius radius to initialize
104  * @param[in] num_warmup Number of warmup samples
105  * @param[in] num_samples Number of samples
106  * @param[in] num_thin Number to thin the samples
107  * @param[in] save_warmup Indicates whether to save the warmup iterations
108  * @param[in] refresh Controls the output
109  * @param[in] stepsize initial stepsize for discrete evolution
110  * @param[in] stepsize_jitter uniform random jitter of stepsize
111  * @param[in] max_depth Maximum tree depth
112  * @param[in,out] interrupt Callback for interrupts
113  * @param[in,out] logger Logger for messages
114  * @param[in,out] init_writer Writer callback for unconstrained inits
115  * @param[in,out] sample_writer Writer for draws
116  * @param[in,out] diagnostic_writer Writer for diagnostic information
117  * @return error_codes::OK if successful
118  *
119  */
120  template <class Model>
122  unsigned int random_seed, unsigned int chain,
123  double init_radius, int num_warmup, int num_samples,
124  int num_thin, bool save_warmup, int refresh,
125  double stepsize, double stepsize_jitter,
126  int max_depth,
127  callbacks::interrupt& interrupt,
128  callbacks::logger& logger,
129  callbacks::writer& init_writer,
130  callbacks::writer& sample_writer,
131  callbacks::writer& diagnostic_writer) {
132  stan::io::dump dmp =
133  util::create_unit_e_dense_inv_metric(model.num_params_r());
134  stan::io::var_context& unit_e_metric = dmp;
135 
136  return hmc_nuts_dense_e(model, init, unit_e_metric,
137  random_seed, chain, init_radius, num_warmup,
138  num_samples, num_thin, save_warmup, refresh,
139  stepsize, stepsize_jitter, max_depth,
140  interrupt, logger,
141  init_writer, sample_writer, diagnostic_writer);
142  }
143 
144  }
145  }
146 }
147 #endif
void validate_dense_inv_metric(const Eigen::MatrixXd &inv_metric, callbacks::logger &logger)
stan::io::dump create_unit_e_dense_inv_metric(size_t num_params)
rosenbrock_model_namespace::rosenbrock_model Model
int hmc_nuts_dense_e(Model &model, stan::io::var_context &init, stan::io::var_context &init_inv_metric, unsigned int random_seed, unsigned int chain, double init_radius, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, callbacks::interrupt &interrupt, callbacks::logger &logger, callbacks::writer &init_writer, callbacks::writer &sample_writer, callbacks::writer &diagnostic_writer)
Eigen::MatrixXd read_dense_inv_metric(stan::io::var_context &init_context, size_t num_params, callbacks::logger &logger)
chain
Check that an output directory exists.
std::vector< double > initialize(Model &model, stan::io::var_context &init, RNG &rng, double init_radius, bool print_timing, stan::callbacks::logger &logger, stan::callbacks::writer &init_writer)
Definition: initialize.hpp:68
void domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
void run_sampler(stan::mcmc::base_mcmc &sampler, Model &model, std::vector< double > &cont_vector, int num_warmup, int num_samples, int num_thin, int refresh, bool save_warmup, RNG &rng, callbacks::interrupt &interrupt, callbacks::logger &logger, callbacks::writer &sample_writer, callbacks::writer &diagnostic_writer)
Definition: run_sampler.hpp:36
Float_t e
Definition: plot.C:35
const XML_Char XML_Content * model
Definition: expat.h:151
boost::ecuyer1988 create_rng(unsigned int seed, unsigned int chain)
Definition: create_rng.hpp:25
void refresh()
Definition: show_event.C:21