SimpleModel.cxx
Go to the documentation of this file.
2 
#include <stdexcept>
#include <string>
#include <utility>
#include "tensorflow/core/platform/env.h"
6 
7 
8 tensorflow::Tensor SimpleModel::getFakeProngTensor(
9  const std::vector<std::string> &vars, float fillValue
10 )
11 {
12  tensorflow::Tensor result(
13  tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 1, int(vars.size())})
14  );
15 
16  auto resultData = result.tensor<float, 3>();
17 
18  for (size_t idx = 0; idx < vars.size(); idx++) {
19  resultData(0, 0, int(idx)) = fillValue;
20  }
21 
22  return result;
23 }
24 
25 tensorflow::Tensor SimpleModel::getSliceTensor(
26  const std::unordered_map<std::string, double> &varMap,
27  const std::vector<std::string> &vars
28 )
29 {
30  tensorflow::Tensor result(
31  tensorflow::DT_FLOAT, tensorflow::TensorShape({1, int(vars.size())})
32  );
33  auto resultData = result.tensor<float, 2>();
34 
35  for (size_t idx = 0; idx < vars.size(); idx++) {
36  resultData(0, int(idx)) = varMap.at(vars[idx]);
37  }
38 
39  return result;
40 }
41 
42 tensorflow::Tensor SimpleModel::getProngTensor(
43  const std::unordered_map<std::string, std::vector<double>> &varMap,
44  const std::vector<std::string> &vars
45 )
46 {
47  if (vars.empty()) {
48  return getFakeProngTensor(vars, 0.0);
49  }
50 
51  const size_t nProngs = (varMap.empty()) ? 0 : varMap.at(vars[0]).size();
52 
53  if (nProngs == 0) {
54  /*
55  * NOTE: Fake tensor with nProngs == 1 is needed, since otherwise
56  * tensorflow fails to infer graph dimensions.
57  */
58  return getFakeProngTensor(vars, 0.0);
59  }
60 
61  tensorflow::Tensor result(
62  tensorflow::DT_FLOAT,
63  tensorflow::TensorShape({1, int(nProngs), int(vars.size())})
64  );
65  auto resultData = result.tensor<float, 3>();
66 
67  for (size_t idx = 0; idx < vars.size(); idx++)
68  {
69  const auto &values = varMap.at(vars[idx]);
70 
71  if (values.size() != nProngs) {
72  throw std::runtime_error("Prongs have different lengths");
73  }
74 
75  for (size_t pngIdx = 0; pngIdx < nProngs; pngIdx++) {
76  resultData(0, pngIdx, int(idx)) = values[pngIdx];
77  }
78  }
79 
80  return result;
81 }
82 
83 void SimpleModel::init() const
84 {
85  if (initialized) {
86  return;
87  }
88 
89  config.load();
90  tfHandler = std::make_shared<tensorflow::TFHandler>(config.getModelPath());
91 
92  initialized = true;
93 }
94 
96  const std::string &savedir,
97  const std::vector<InputConfigKeys> &scalarInputKeys,
98  const std::vector<InputConfigKeys> &vectorInputKeys,
99  const std::vector<std::string> &outputKeys
100 ) : config(savedir, scalarInputKeys, vectorInputKeys, outputKeys),
101  tfHandler(nullptr),
102  initialized(false)
103 { }
104 
105 std::vector<tensorflow::Tensor> SimpleModel::predict(const VarDict &varDict)
106  const
107 {
108  init();
109 
110  std::vector<std::pair<std::string, tensorflow::Tensor>> inputs;
111  inputs.reserve(config.scalarInputs.size() + config.vectorInputs.size());
112 
113  for (const auto &inputConfig : config.scalarInputs)
114  {
115  inputs.emplace_back(std::make_pair(
116  inputConfig.nodeName,
117  getSliceTensor(varDict.scalarVarMap, inputConfig.varNames)
118  ));
119  }
120 
121  for (const auto &inputConfig : config.vectorInputs)
122  {
123  inputs.emplace_back(std::make_pair(
124  inputConfig.nodeName,
125  getProngTensor(varDict.vectorVarMap, inputConfig.varNames)
126  ));
127  }
128 
129  return tfHandler->Predict(inputs, config.outputNodes);
130 }
131 
std::unordered_map< std::string, std::vector< double > > vectorVarMap
Definition: VarDict.h:11
void init() const
Definition: SimpleModel.cxx:83
Definition: config.py:1
std::pair< Spectrum *, CheatDecomp * > make_pair(SpectrumLoaderBase &loader_data, SpectrumLoaderBase &loader_mc, HistAxis *axis, Cut *cut, const SystShifts &shift, const Var &wei)
Definition: DataMCLoad.C:336
static tensorflow::Tensor getSliceTensor(const std::unordered_map< std::string, double > &varMap, const std::vector< std::string > &vars)
Definition: SimpleModel.cxx:25
static tensorflow::Tensor getProngTensor(const std::unordered_map< std::string, std::vector< double >> &varMap, const std::vector< std::string > &vars)
Definition: SimpleModel.cxx:42
static tensorflow::Tensor getFakeProngTensor(const std::vector< std::string > &vars, float fillValue=0.0)
Definition: SimpleModel.cxx:8
const std::map< std::pair< std::string, std::string >, Variable > vars
bool initialized
Definition: SimpleModel.h:18
::xsd::cxx::tree::string< char, simple_type > string
Definition: Database.h:154
std::shared_ptr< tensorflow::TFHandler > tfHandler
Definition: SimpleModel.h:17
SimpleModel(const std::string &savedir, const std::vector< InputConfigKeys > &scalarInputKeys, const std::vector< InputConfigKeys > &vectorInputKeys, const std::vector< std::string > &outputKeys)
Definition: SimpleModel.cxx:95
Definition: VarDict.h:7
std::vector< tensorflow::Tensor > predict(const VarDict &varDict) const
std::unordered_map< std::string, double > scalarVarMap
Definition: VarDict.h:10