mcmc_writer.hpp
Go to the documentation of this file.
1 #ifndef STAN_SERVICES_UTIL_MCMC_WRITER_HPP
2 #define STAN_SERVICES_UTIL_MCMC_WRITER_HPP
3 
7 #include <stan/mcmc/sample.hpp>
9 #include <iomanip>
10 #include <limits>
11 #include <sstream>
12 #include <string>
13 #include <vector>
14 
15 namespace stan {
16 namespace services {
17 namespace util {
18 
19 /**
20  * mcmc_writer writes out headers and samples
21  *
22  * The class itself is not a template; methods that need the model are templated on the Model type individually.
23  */
24 class mcmc_writer {
25  private:
29 
30  public:
34  /**
35  * Constructor.
36  *
37  * @param[in,out] sample_writer samples are "written" to this stream
38  * @param[in,out] diagnostic_writer diagnostic info is "written" to this
39  * stream
40  * @param[in,out] logger messages are written through the logger
41  */
43  callbacks::writer& diagnostic_writer,
44  callbacks::logger& logger)
45  : sample_writer_(sample_writer),
46  diagnostic_writer_(diagnostic_writer),
47  logger_(logger),
48  num_sample_params_(0),
49  num_sampler_params_(0),
50  num_model_params_(0) {
51  }
52 
53  /**
54  * Outputs parameter string names. First outputs the names stored in
55  * the sample object (stan::mcmc::sample), then uses the sampler
56  * provided to output sampler specific names, then adds the model
57  * constrained parameter names.
58  *
59  * The names are written to the sample_stream as comma separated values
60  * with a newline at the end.
61  *
62  * @param[in] sample a sample (unconstrained) that works with the model
63  * @param[in] sampler a stan::mcmc::base_mcmc object
64  * @param[in] model the model
65  */
66  template <class Model>
68  stan::mcmc::base_mcmc& sampler,
69  Model& model) {
70  std::vector<std::string> names;
71 
72  sample.get_sample_param_names(names);
73  num_sample_params_ = names.size();
74 
75  sampler.get_sampler_param_names(names);
76  num_sampler_params_ = names.size() - num_sample_params_;
77 
78  model.constrained_param_names(names, true, true);
79  num_model_params_
80  = names.size() - num_sample_params_ - num_sampler_params_;
81 
82  sample_writer_(names);
83  }
84 
85  /**
86  * Outputs samples. First outputs the values of the sample params
87  * from a stan::mcmc::sample, then outputs the values of the sampler
88  * params from a stan::mcmc::base_mcmc, then finally outputs the values
89  * of the model.
90  *
91  * The samples are written to the sample_stream as comma separated
92  * values with a newline at the end.
93  *
94  * @param[in,out] rng random number generator (used by
95  * model.write_array())
96  * @param[in] sample the sample in constrained space
97  * @param[in] sampler the sampler
98  * @param[in] model the model
99  */
100  template <class Model, class RNG>
101  void write_sample_params(RNG& rng,
102  stan::mcmc::sample& sample,
103  stan::mcmc::base_mcmc& sampler,
104  Model& model) {
105  std::vector<double> values;
106 
107  sample.get_sample_params(values);
108  sampler.get_sampler_params(values);
109 
110  std::vector<double> model_values;
111  std::vector<int> params_i;
112  std::stringstream ss;
113  try {
114  std::vector<double> cont_params(sample.cont_params().data(),
115  sample.cont_params().data()
116  + sample.cont_params().size());
117  model.write_array(rng,
118  cont_params,
119  params_i,
120  model_values,
121  true, true,
122  &ss);
123  } catch (const std::exception& e) {
124  if (ss.str().length() > 0)
125  logger_.info(ss);
126  ss.str("");
127  logger_.info(e.what());
128  }
129  if (ss.str().length() > 0)
130  logger_.info(ss);
131 
132  if (model_values.size() > 0)
133  values.insert(values.end(), model_values.begin(), model_values.end());
134  if (model_values.size() < num_model_params_)
135  values.insert(values.end(),
136  num_model_params_ - model_values.size(),
137  std::numeric_limits<double>::quiet_NaN());
138 
139 
140  sample_writer_(values);
141  }
142 
143  /**
144  * Prints additional info to the streams
145  *
146  * Prints to the sample stream
147  *
148  * @param[in] sampler sampler
149  */
151  sample_writer_("Adaptation terminated");
152  }
153 
154  /**
155  * Print diagnostic names
156  *
157  * @param[in] sample unconstrained sample
158  * @param[in] sampler sampler
159  * @param[in] model model
160  */
161  template <class Model>
163  stan::mcmc::base_mcmc& sampler,
164  Model& model) {
165  std::vector<std::string> names;
166 
167  sample.get_sample_param_names(names);
168  sampler.get_sampler_param_names(names);
169 
170  std::vector<std::string> model_names;
171  model.unconstrained_param_names(model_names, false, false);
172 
173  sampler.get_sampler_diagnostic_names(model_names, names);
174 
175  diagnostic_writer_(names);
176  }
177 
178  /**
179  * Print diagnostic params to the diagnostic stream.
180  *
181  * @param[in] sample unconstrained sample
182  * @param[in] sampler sampler
183  */
185  stan::mcmc::base_mcmc& sampler) {
186  std::vector<double> values;
187 
188  sample.get_sample_params(values);
189  sampler.get_sampler_params(values);
190  sampler.get_sampler_diagnostics(values);
191 
192  diagnostic_writer_(values);
193  }
194 
195  /**
196  * Internal method
197  *
198  * Prints timing information
199  *
200  * @param[in] warmDeltaT warmup time in seconds
201  * @param[in] sampleDeltaT sample time in seconds
202  * @param[in,out] writer output stream
203  */
204  void write_timing(double warmDeltaT, double sampleDeltaT,
205  callbacks::writer& writer) {
206  std::string title(" Elapsed Time: ");
207  writer();
208 
209  std::stringstream ss1;
210  ss1 << title << warmDeltaT << " seconds (Warm-up)";
211  writer(ss1.str());
212 
213  std::stringstream ss2;
214  ss2 << std::string(title.size(), ' ') << sampleDeltaT
215  << " seconds (Sampling)";
216  writer(ss2.str());
217 
218  std::stringstream ss3;
219  ss3 << std::string(title.size(), ' ')
220  << warmDeltaT + sampleDeltaT
221  << " seconds (Total)";
222  writer(ss3.str());
223 
224  writer();
225  }
226 
227  /**
228  * Internal method
229  *
230  * Logs timing information
231  *
232  * @param[in] warmDeltaT warmup time in seconds
233  * @param[in] sampleDeltaT sample time in seconds
234  */
235  void log_timing(double warmDeltaT, double sampleDeltaT) {
236  std::string title(" Elapsed Time: ");
237  logger_.info("");
238 
239  std::stringstream ss1;
240  ss1 << title << warmDeltaT << " seconds (Warm-up)";
241  logger_.info(ss1);
242 
243  std::stringstream ss2;
244  ss2 << std::string(title.size(), ' ') << sampleDeltaT
245  << " seconds (Sampling)";
246  logger_.info(ss2);
247 
248  std::stringstream ss3;
249  ss3 << std::string(title.size(), ' ')
250  << warmDeltaT + sampleDeltaT
251  << " seconds (Total)";
252  logger_.info(ss3);
253 
254  logger_.info("");
255  }
256 
257  /**
258  * Print timing information to all streams
259  *
260  * @param[in] warmDeltaT warmup time (sec)
261  * @param[in] sampleDeltaT sample time (sec)
262  */
263  void write_timing(double warmDeltaT, double sampleDeltaT) {
264  write_timing(warmDeltaT, sampleDeltaT, sample_writer_);
265  write_timing(warmDeltaT, sampleDeltaT, diagnostic_writer_);
266  log_timing(warmDeltaT, sampleDeltaT);
267  }
268 };
269 
270 }
271 }
272 }
273 #endif
Filter events based on their run/event numbers.
static void get_sample_param_names(std::vector< std::string > &names)
Definition: sample.hpp:44
void write_sample_params(RNG &rng, stan::mcmc::sample &sample, stan::mcmc::base_mcmc &sampler, Model &model)
callbacks::writer & sample_writer_
Definition: mcmc_writer.hpp:26
virtual void get_sampler_diagnostic_names(std::vector< std::string > &model_names, std::vector< std::string > &names)
Definition: base_mcmc.hpp:31
void log_timing(double warmDeltaT, double sampleDeltaT)
void write_diagnostic_params(stan::mcmc::sample &sample, stan::mcmc::base_mcmc &sampler)
Float_t ss
Definition: plot.C:24
void write_adapt_finish(stan::mcmc::base_mcmc &sampler)
virtual void get_sampler_params(std::vector< double > &values)
Definition: base_mcmc.hpp:25
::xsd::cxx::tree::exception< char > exception
Definition: Database.h:225
void get_sample_params(std::vector< double > &values)
Definition: sample.hpp:49
rosenbrock_model_namespace::rosenbrock_model Model
virtual void get_sampler_diagnostics(std::vector< double > &values)
Definition: base_mcmc.hpp:34
void write_timing(double warmDeltaT, double sampleDeltaT)
double cont_params(int k) const
Definition: sample.hpp:24
virtual void get_sampler_param_names(std::vector< std::string > &names)
Definition: base_mcmc.hpp:23
void write_sample_names(stan::mcmc::sample &sample, stan::mcmc::base_mcmc &sampler, Model &model)
Definition: mcmc_writer.hpp:67
void write_diagnostic_names(stan::mcmc::sample sample, stan::mcmc::base_mcmc &sampler, Model &model)
virtual void info(const std::string &message)
Definition: logger.hpp:47
::xsd::cxx::tree::string< char, simple_type > string
Definition: Database.h:154
mcmc_writer(callbacks::writer &sample_writer, callbacks::writer &diagnostic_writer, callbacks::logger &logger)
Definition: mcmc_writer.hpp:42
callbacks::writer & diagnostic_writer_
Definition: mcmc_writer.hpp:27
Float_t e
Definition: plot.C:35
const XML_Char XML_Content * model
Definition: expat.h:151
void write_timing(double warmDeltaT, double sampleDeltaT, callbacks::writer &writer)