hmc_nuts_diag_e_test.cpp
Go to the documentation of this file.
1 #include <stan/math/prim/mat.hpp>
3 #include <gtest/gtest.h>
5 #include <test/test-models/good/optimization/rosenbrock.hpp>
7 #include <iostream>
8 
// Test fixture used by the hmc_nuts_diag_e service tests in this file; it owns
// the model under test and the stringstream handed to the model's constructor.
// NOTE(review): this listing omits the constructor's declaration line and
// listing lines 15-17.
9 class ServicesSampleHmcNutsDiagE : public testing::Test {
10 public:
12  : model(context, &model_log) {}
13 
 // Passed by address to model's constructor (presumably as its message
 // stream). It is declared before `model` so that, under C++ member
 // initialization order, it already exists when `model` is constructed.
14  std::stringstream model_log;
 // stan_model presumably comes from the included rosenbrock.hpp test model.
 // `context` is declared on lines not shown in this listing (the page's
 // tooltip lists it as stan::io::empty_var_context) — confirm in the source.
18  stan_model model;
19 };
20 
22  unsigned int random_seed = 0;
23  unsigned int chain = 1;
24  double init_radius = 0;
25  int num_warmup = 200;
26  int num_samples = 400;
27  int num_thin = 5;
28  bool save_warmup = true;
29  int refresh = 0;
30  double stepsize = 0.1;
31  double stepsize_jitter = 0;
32  int max_depth = 8;
34  EXPECT_EQ(interrupt.call_count(), 0);
35 
37  model, context, random_seed, chain, init_radius,
38  num_warmup, num_samples, num_thin, save_warmup, refresh,
39  stepsize, stepsize_jitter, max_depth,
40  interrupt, logger, init,
42 
43  EXPECT_EQ(0, return_code);
44 
45  int num_output_lines = (num_warmup+num_samples)/num_thin;
46  EXPECT_EQ(num_warmup+num_samples, interrupt.call_count());
47  EXPECT_EQ(1, parameter.call_count("vector_string"));
48  EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
49  EXPECT_EQ(1, diagnostic.call_count("vector_string"));
50  EXPECT_EQ(num_output_lines, diagnostic.call_count("vector_double"));
51 }
52 
53 
// Runs the sampler once with fixed settings and checks the column headers and
// row count recorded by the `parameter` and `diagnostic` writers.
// NOTE(review): this listing omits the call's opening line and its final
// arguments; presumably the call is hmc_nuts_diag_e(model, context, ...) —
// confirm against the source file.
54 TEST_F(ServicesSampleHmcNutsDiagE, parameter_checks) {
55  unsigned int random_seed = 0;
56  unsigned int chain = 1;
57  double init_radius = 0;
58  int num_warmup = 200;
59  int num_samples = 400;
60  int num_thin = 5;
61  bool save_warmup = true;
62  int refresh = 0;
63  double stepsize = 0.1;
64  double stepsize_jitter = 0;
65  int max_depth = 8;
 // Sanity check: no interrupt calls are recorded before the run.
67  EXPECT_EQ(interrupt.call_count(), 0);
68 
70  model, context, random_seed, chain, init_radius,
71  num_warmup, num_samples, num_thin, save_warmup, refresh,
72  stepsize, stepsize_jitter, max_depth,
73  interrupt, logger, init,
75 
 // Snapshot the header rows and value rows that both writers recorded.
76  std::vector<std::vector<std::string> > parameter_names;
77  parameter_names = parameter.vector_string_values();
78  std::vector<std::vector<double> > parameter_values;
79  parameter_values = parameter.vector_double_values();
80  std::vector<std::vector<std::string> > diagnostic_names;
81  diagnostic_names = diagnostic.vector_string_values();
82  std::vector<std::vector<double> > diagnostic_values;
83  diagnostic_values = diagnostic.vector_double_values();
84 
85  // Expected parameter column names, in order: lp__ through energy__, then
 // x and y.
 // NOTE(review): parameter_names[0] is indexed without first asserting that
 // parameter_names is non-empty; ASSERT_FALSE(parameter_names.empty()) would
 // turn a potential crash into a clean test failure.
86  ASSERT_EQ(9, parameter_names[0].size());
87  EXPECT_EQ("lp__", parameter_names[0][0]);
88  EXPECT_EQ("accept_stat__", parameter_names[0][1]);
89  EXPECT_EQ("stepsize__", parameter_names[0][2]);
90  EXPECT_EQ("treedepth__", parameter_names[0][3]);
91  EXPECT_EQ("n_leapfrog__", parameter_names[0][4]);
92  EXPECT_EQ("divergent__", parameter_names[0][5]);
93  EXPECT_EQ("energy__", parameter_names[0][6]);
94  EXPECT_EQ("x", parameter_names[0][7]);
95  EXPECT_EQ("y", parameter_names[0][8]);
96 
97  // Each header row has exactly one name per entry of the first value row.
 // NOTE(review): parameter_values[0], diagnostic_names[0] and
 // diagnostic_values[0] are likewise indexed without a non-empty check.
98  EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size());
99  EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size());
100 
 // Expect (num_warmup + num_samples) / num_thin rows of parameter values
 // (120 with the settings above). Warm-up rows presumably count here because
 // save_warmup is true.
101  EXPECT_EQ((num_warmup+num_samples)/num_thin, parameter_values.size());
102 
103  // The diagnostic header should also begin with lp__ and accept_stat__;
104  // only these first two diagnostic names are checked.
 // NOTE(review): no size assertion on diagnostic_names[0] precedes these
 // indexed reads.
105  EXPECT_EQ("lp__", diagnostic_names[0][0]);
106  EXPECT_EQ("accept_stat__", diagnostic_names[0][1]);
107 
108 }
109 
111  unsigned int random_seed = 0;
112  unsigned int chain = 1;
113  double init_radius = 0;
114  int num_warmup = 200;
115  int num_samples = 400;
116  int num_thin = 5;
117  bool save_warmup = true;
118  int refresh = 0;
119  double stepsize = 0.1;
120  double stepsize_jitter = 0;
121  int max_depth = 8;
123  EXPECT_EQ(interrupt.call_count(), 0);
124 
126  model, context, random_seed, chain, init_radius,
127  num_warmup, num_samples, num_thin, save_warmup, refresh,
128  stepsize, stepsize_jitter, max_depth,
129  interrupt, logger, init,
131 
132  std::vector<std::vector<std::string> > parameter_names;
133  parameter_names = parameter.vector_string_values();
134  std::vector<std::vector<double> > parameter_values;
135  parameter_values = parameter.vector_double_values();
136  std::vector<std::vector<std::string> > diagnostic_names;
137  diagnostic_names = diagnostic.vector_string_values();
138  std::vector<std::vector<double> > diagnostic_values;
139  diagnostic_values = diagnostic.vector_double_values();
140 
141  EXPECT_EQ(return_code, 0);
142 
143 }
144 
// Runs the sampler with the same fixed settings used by the other tests in
// this file, then checks what the `init` writer and the `logger` recorded.
// NOTE(review): this listing omits the call's opening line (161), its final
// arguments (166) and listing line 157; presumably the call is
// hmc_nuts_diag_e(model, context, ...) — confirm against the source file.
145 TEST_F(ServicesSampleHmcNutsDiagE, output_regression) {
146  unsigned int random_seed = 0;
147  unsigned int chain = 1;
148  double init_radius = 0;
149  int num_warmup = 200;
150  int num_samples = 400;
151  int num_thin = 5;
152  bool save_warmup = true;
153  int refresh = 0;
154  double stepsize = 0.1;
155  double stepsize_jitter = 0;
156  int max_depth = 8;
 // Sanity check: no interrupt calls are recorded before the run.
158  EXPECT_EQ(interrupt.call_count(), 0);
159 
160 
162  model, context, random_seed, chain, init_radius,
163  num_warmup, num_samples, num_thin, save_warmup, refresh,
164  stepsize, stepsize_jitter, max_depth,
165  interrupt, logger, init,
167 
 // The init writer should not have recorded any plain string values.
168  std::vector<std::string> init_values;
169  init_values = init.string_values();
170 
 // NOTE(review): this compares an int literal with a size_t;
 // EXPECT_TRUE(init_values.empty()) would avoid a sign-compare warning.
171  EXPECT_EQ(0, init_values.size());
172 
 // Each line of the elapsed-time summary should appear in the info log
 // exactly once (find_info presumably counts info messages that contain the
 // given text — confirm in instrumented_logger)...
173  EXPECT_EQ(1, logger.find_info("Elapsed Time:"));
174  EXPECT_EQ(1, logger.find_info("seconds (Warm-up)"));
175  EXPECT_EQ(1, logger.find_info("seconds (Sampling)"));
176  EXPECT_EQ(1, logger.find_info("seconds (Total)"));
 // ...and presumably nothing should have been logged at error level.
177  EXPECT_EQ(0, logger.call_count_error());
178 }
std::vector< std::vector< std::string > > vector_string_values()
int hmc_nuts_diag_e(Model &model, stan::io::var_context &init, stan::io::var_context &init_inv_metric, unsigned int random_seed, unsigned int chain, double init_radius, int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh, double stepsize, double stepsize_jitter, int max_depth, callbacks::interrupt &interrupt, callbacks::logger &logger, callbacks::writer &init_writer, callbacks::writer &sample_writer, callbacks::writer &diagnostic_writer)
stan::test::unit::instrumented_writer diagnostic
stan::test::unit::instrumented_logger logger
TEST_F(ServicesSampleHmcNutsDiagE, call_count)
chain
Check that an output directory exists.
stan::test::unit::instrumented_writer init
stan::test::unit::instrumented_writer parameter
const XML_Char * context
Definition: expat.h:434
stan::io::empty_var_context context
std::vector< std::vector< double > > vector_double_values()
unsigned int find_info(const std::string &msg)
const XML_Char XML_Content * model
Definition: expat.h:151
void refresh()
Definition: show_event.C:21