2 #include <gtest/gtest.h> 4 #include <test/test-models/good/optimization/rosenbrock.hpp> 21 std::vector<std::string> names_;
22 std::vector<std::vector<double> > states_;
43 states_.push_back(state);
53 parameter(parameter_ss),
66 unsigned int seed = 0;
67 unsigned int chain = 1;
68 double init_radius = 0;
70 bool save_iterations =
true;
75 seed, chain, init_radius,
84 save_iterations, refresh,
90 EXPECT_EQ(logger.call_count(), logger.call_count_info()) <<
"all output to info";
91 EXPECT_EQ(1, logger.find(
"Initial log joint probability = -1"));
92 EXPECT_EQ(1, logger.find(
"Optimization terminated normally: "));
93 EXPECT_EQ(1, logger.find(
" Convergence detected: relative gradient magnitude is below tolerance"));
95 EXPECT_EQ(
"0,0\n", init_ss.str());
97 ASSERT_EQ(3, parameter.names_.size());
98 EXPECT_EQ(
"lp__", parameter.names_[0]);
99 EXPECT_EQ(
"x", parameter.names_[1]);
100 EXPECT_EQ(
"y", parameter.names_[2]);
102 EXPECT_EQ(23, parameter.states_.size());
103 EXPECT_FLOAT_EQ(0, parameter.states_.front()[1])
104 <<
"initial value should be (0, 0)";
105 EXPECT_FLOAT_EQ(0, parameter.states_.front()[2])
106 <<
"initial value should be (0, 0)";
107 EXPECT_FLOAT_EQ(0.99998301, parameter.states_.back()[1])
108 <<
"optimal value should be (1, 1)";
109 EXPECT_FLOAT_EQ(0.99996597, parameter.states_.back()[2])
110 <<
"optimal value should be (1, 1)";
111 EXPECT_FLOAT_EQ(return_code, 0);
112 EXPECT_EQ(22, callback.
n);
void operator()(const std::vector< double > &state)
stan::test::unit::instrumented_logger logger
TEST_F(ServicesOptimizeLbfgs, rosenbrock)
def callbacks(model_name, group, tensorboard=True)
chain
Check that an output directory exists.
stan::io::empty_var_context context
std::stringstream parameter_ss
values(std::ostream &stream)
stan::callbacks::stream_writer init
void operator()(const std::vector< std::string > &names)
const XML_Char XML_Content * model
int lbfgs(Model &model, stan::io::var_context &init, unsigned int random_seed, unsigned int chain, double init_radius, int history_size, double init_alpha, double tol_obj, double tol_rel_obj, double tol_grad, double tol_rel_grad, double tol_param, int num_iterations, bool save_iterations, int refresh, callbacks::interrupt &interrupt, callbacks::logger &logger, callbacks::writer &init_writer, callbacks::writer ¶meter_writer)