hmc_static_diag_e_test.cpp
Go to the documentation of this file.
1 #include <stan/math/prim/mat.hpp>
3 #include <gtest/gtest.h>
5 #include <test/test-models/good/optimization/rosenbrock.hpp>
7 #include <iostream>
8 
// Test fixture for the ServicesSampleHmcStaticDiagE tests: holds the model the
// sampler is run on, plus a string stream whose address is passed to the
// model's constructor.
// NOTE(review): this listing is a Doxygen scrape; the constructor's declarator
// line and some member declarations are missing here, so the class is
// incomplete as shown.
9 class ServicesSampleHmcStaticDiagE : public testing::Test {
10 public:
12  : model(context, &model_log) {}
13 
14  std::stringstream model_log;
// `model` is declared after `model_log`. Members are initialized in
// declaration order, so `model_log` is fully constructed before `model`'s
// constructor receives its address. Keep this ordering.
18  stan_model model;
19 };
20 
// Runs hmc_static_diag_e with the settings below. Checks the return code and
// how many times the interrupt callback and the parameter/diagnostic writers
// were called.
// NOTE(review): this is a Doxygen-scraped listing. Three lines are missing
// here: the TEST_F header, the line that starts the sampler call (presumably
// assigning return_code), and the call's closing argument line.
// NOTE(review): the same sampler settings are repeated verbatim in the other
// tests of this file; consider hoisting them into the fixture.
22  unsigned int random_seed = 0;
23  unsigned int chain = 1;
24  double init_radius = 0;
25  int num_warmup = 200;
26  int num_samples = 400;
27  int num_thin = 5;
28  bool save_warmup = true;
29  int refresh = 0;
30  double stepsize = 0.1;
31  double stepsize_jitter = 0;
32  double int_time = 8;
34  EXPECT_EQ(interrupt.call_count(), 0);
35 
37  model, context, random_seed, chain, init_radius,
38  num_warmup, num_samples, num_thin, save_warmup, refresh,
39  stepsize, stepsize_jitter, int_time,
40  interrupt, logger, init,
42 
43  EXPECT_EQ(0, return_code);
44 
// The test expects warmup draws to be written too (save_warmup is true), so
// each writer should receive (warmup + sampling iterations) / thin value rows.
// NOTE(review): integer division; exact for these settings (600 / 5). Verify
// the formula if num_thin is changed so it no longer divides the total.
45  int num_output_lines = (num_warmup+num_samples)/num_thin;
// Expected: one interrupt call per iteration, one header (vector_string) call
// per writer, and num_output_lines value (vector_double) calls per writer.
46  EXPECT_EQ(num_warmup+num_samples, interrupt.call_count());
47  EXPECT_EQ(1, parameter.call_count("vector_string"));
48  EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
49  EXPECT_EQ(1, diagnostic.call_count("vector_string"));
50  EXPECT_EQ(num_output_lines, diagnostic.call_count("vector_double"));
51 }
52 
// Checks the column headers, and the shape of the value rows, that the sampler
// writes to the parameter and diagnostic writers.
// NOTE(review): Doxygen-scraped listing. The TEST_F header and the lines that
// start and close the sampler call are missing here. No expectation on the
// sampler's return value appears in this test.
54  unsigned int random_seed = 0;
55  unsigned int chain = 1;
56  double init_radius = 0;
57  int num_warmup = 200;
58  int num_samples = 400;
59  int num_thin = 5;
60  bool save_warmup = true;
61  int refresh = 0;
62  double stepsize = 0.1;
63  double stepsize_jitter = 0;
64  double int_time = 8;
66  EXPECT_EQ(interrupt.call_count(), 0);
67 
69  model, context, random_seed, chain, init_radius,
70  num_warmup, num_samples, num_thin, save_warmup, refresh,
71  stepsize, stepsize_jitter, int_time,
72  interrupt, logger, init,
74 
75  std::vector<std::vector<std::string> > parameter_names;
76  parameter_names = parameter.vector_string_values();
77  std::vector<std::vector<double> > parameter_values;
78  parameter_values = parameter.vector_double_values();
79  std::vector<std::vector<std::string> > diagnostic_names;
80  diagnostic_names = diagnostic.vector_string_values();
81  std::vector<std::vector<double> > diagnostic_values;
82  diagnostic_values = diagnostic.vector_double_values();
83 
84  // Expected parameter-writer header: the sampler columns, followed by the
// model parameters x and y.
// NOTE(review): parameter_names[0] is indexed without first checking that
// parameter_names is non-empty. An empty result would be undefined behavior
// rather than a clean test failure.
85  ASSERT_EQ(7, parameter_names[0].size());
86  EXPECT_EQ("lp__", parameter_names[0][0]);
87  EXPECT_EQ("accept_stat__", parameter_names[0][1]);
88  EXPECT_EQ("stepsize__", parameter_names[0][2]);
89  EXPECT_EQ("int_time__", parameter_names[0][3]);
90  EXPECT_EQ("energy__", parameter_names[0][4]);
91  EXPECT_EQ("x", parameter_names[0][5]);
92  EXPECT_EQ("y", parameter_names[0][6]);
93 
94  // The first value row of each writer has exactly one value per header name.
95  EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size());
96  EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size());
97 
// One parameter value row per num_thin iterations, warmup included
// (save_warmup is true).
98  EXPECT_EQ((num_warmup+num_samples)/num_thin, parameter_values.size());
99 
100  // The diagnostic writer's header starts with the same two sampler
101  // columns as the parameter writer's header.
// NOTE(review): diagnostic_names[0] is indexed without a size check.
102  EXPECT_EQ("lp__", diagnostic_names[0][0]);
103  EXPECT_EQ("accept_stat__", diagnostic_names[0][1]);
104 }
105 
// Runs hmc_static_diag_e and checks that it returns 0.
// NOTE(review): Doxygen-scraped listing. The TEST_F header and the lines that
// start and close the sampler call are missing here.
// NOTE(review): the four writer outputs extracted below are never inspected.
// Either add expectations on them or drop the extraction.
107  unsigned int random_seed = 0;
108  unsigned int chain = 1;
109  double init_radius = 0;
110  int num_warmup = 200;
111  int num_samples = 400;
112  int num_thin = 5;
113  bool save_warmup = true;
114  int refresh = 0;
115  double stepsize = 0.1;
116  double stepsize_jitter = 0;
117  double int_time = 8;
119  EXPECT_EQ(interrupt.call_count(), 0);
120 
122  model, context, random_seed, chain, init_radius,
123  num_warmup, num_samples, num_thin, save_warmup, refresh,
124  stepsize, stepsize_jitter, int_time,
125  interrupt, logger, init,
127 
128  std::vector<std::vector<std::string> > parameter_names;
129  parameter_names = parameter.vector_string_values();
130  std::vector<std::vector<double> > parameter_values;
131  parameter_values = parameter.vector_double_values();
132  std::vector<std::vector<std::string> > diagnostic_names;
133  diagnostic_names = diagnostic.vector_string_values();
134  std::vector<std::vector<double> > diagnostic_values;
135  diagnostic_values = diagnostic.vector_double_values();
136 
137  EXPECT_EQ(return_code, 0);
138 }
139 
// Runs hmc_static_diag_e and checks its side outputs:
//   - no string values are written to the init writer;
//   - each timing summary line is logged once;
//   - no errors are logged.
// NOTE(review): Doxygen-scraped listing. The TEST_F header and the lines that
// start and close the sampler call are missing here.
141  unsigned int random_seed = 0;
142  unsigned int chain = 1;
143  double init_radius = 0;
144  int num_warmup = 200;
145  int num_samples = 400;
146  int num_thin = 5;
147  bool save_warmup = true;
148  int refresh = 0;
149  double stepsize = 0.1;
150  double stepsize_jitter = 0;
151  double int_time = 8;
153  EXPECT_EQ(interrupt.call_count(), 0);
154 
155 
157  model, context, random_seed, chain, init_radius,
158  num_warmup, num_samples, num_thin, save_warmup, refresh,
159  stepsize, stepsize_jitter, int_time,
160  interrupt, logger, init,
162 
163  std::vector<std::string> init_values;
164  init_values = init.string_values();
165 
166  EXPECT_EQ(0, init_values.size());
167 
// NOTE(review): find_info presumably returns how many info messages contain
// the given text. Confirm against instrumented_logger.
168  EXPECT_EQ(1, logger.find_info("Elapsed Time:"));
169  EXPECT_EQ(1, logger.find_info("seconds (Warm-up)"));
170  EXPECT_EQ(1, logger.find_info("seconds (Sampling)"));
171  EXPECT_EQ(1, logger.find_info("seconds (Total)"));
172  EXPECT_EQ(0, logger.call_count_error());
173 }
std::vector< std::vector< std::string > > vector_string_values()
stan::io::empty_var_context context
stan::test::unit::instrumented_writer init
TEST_F(ServicesSampleHmcStaticDiagE, call_count)
chain
Check that an output directory exists.
stan::test::unit::instrumented_writer diagnostic
const XML_Char * context
Definition: expat.h:434
stan::test::unit::instrumented_writer parameter
int hmc_static_diag_e(Model &model, stan::io::var_context &init, stan::io::var_context &init_inv_metric, unsigned int random_seed, unsigned int chain, double init_radius, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, double int_time, callbacks::interrupt &interrupt, callbacks::logger &logger, callbacks::writer &init_writer, callbacks::writer &sample_writer, callbacks::writer &diagnostic_writer)
std::vector< std::vector< double > > vector_double_values()
stan::test::unit::instrumented_logger logger
unsigned int find_info(const std::string &msg)
const XML_Char XML_Content * model
Definition: expat.h:151
void refresh()
Definition: show_event.C:21