Model.cxx
Go to the documentation of this file.
2 
3 #include "tensorflow/core/platform/env.h"
4 #include "tensorflow/core/public/session.h"
5 
6 namespace LSTME
7 {
8 
9 const std::vector<InputConfigKeys> Model::sliceInputs({
10  { "input_slice", "vars_slice" }
11 });
12 
13 const std::vector<InputConfigKeys> Model::prongInputs({
14  { "input_png2d", "vars_png2d" },
15  { "input_png3d", "vars_png3d" }
16 });
17 
18 const std::vector<std::string> Model::outputs({
19  "target_primary", "target_total"
20 });
21 
22 
23 Model::Model(const std::string &savedir)
24  : simpleModel(savedir, sliceInputs, prongInputs, outputs)
25 { }
26 
28 {
29  std::vector<tensorflow::Tensor> outputs = simpleModel.predict(varDict);
30 
31  const float primaryE = outputs[0].tensor<float,2>()(0, 0);
32  const float totalE = outputs[1].tensor<float,2>()(0, 0);
33 
34  return LSTMEnergy{ primaryE, totalE };
35 }
36 
37 }
static const std::vector< std::string > outputs
Definition: Model.h:22
SimpleModel simpleModel
Definition: Model.h:18
static const std::vector< InputConfigKeys > sliceInputs
Definition: Model.h:20
LSTMEnergy predict(const VarDict &varDict)
Definition: Model.cxx:27
Model(const std::string &savedir)
Definition: Model.cxx:23
static const std::vector< InputConfigKeys > prongInputs
Definition: Model.h:21
::xsd::cxx::tree::string< char, simple_type > string
Definition: Database.h:154
Definition: VarDict.h:7
std::vector< tensorflow::Tensor > predict(const VarDict &varDict) const