Public Member Functions | Static Private Member Functions | Private Attributes | List of all members
SimpleModel Class Reference

#include "/cvmfs/nova-development.opensciencegrid.org/novasoft/releases/N20-11-28/TensorFlowEvaluator/SliceLID/simple_model/SimpleModel.h"

Public Member Functions

 SimpleModel (const std::string &savedir, const std::vector< InputConfigKeys > &scalarInputKeys, const std::vector< InputConfigKeys > &vectorInputKeys, const std::vector< std::string > &outputKeys)
 
void init () const
 
std::vector< tensorflow::Tensor > predict (const VarDict &varDict) const
 

Static Private Member Functions

static tensorflow::Tensor getFakeProngTensor (const std::vector< std::string > &vars, float fillValue=0.0)
 
static tensorflow::Tensor getSliceTensor (const std::unordered_map< std::string, double > &varMap, const std::vector< std::string > &vars)
 
static tensorflow::Tensor getProngTensor (const std::unordered_map< std::string, std::vector< double >> &varMap, const std::vector< std::string > &vars)
 

Private Attributes

ModelConfig config
 
std::shared_ptr< tensorflow::TFHandler > tfHandler
 
bool initialized
 

Detailed Description

Definition at line 13 of file SimpleModel.h.

Constructor & Destructor Documentation

SimpleModel::SimpleModel ( const std::string &  savedir,
const std::vector< InputConfigKeys > &  scalarInputKeys,
const std::vector< InputConfigKeys > &  vectorInputKeys,
const std::vector< std::string > &  outputKeys 
)

Definition at line 95 of file SimpleModel.cxx.

100  : config(savedir, scalarInputKeys, vectorInputKeys, outputKeys),
101  tfHandler(nullptr),
102  initialized(false)
103 { }
bool initialized
Definition: SimpleModel.h:18
std::shared_ptr< tensorflow::TFHandler > tfHandler
Definition: SimpleModel.h:17
ModelConfig config
Definition: SimpleModel.h:16

Member Function Documentation

tensorflow::Tensor SimpleModel::getFakeProngTensor ( const std::vector< std::string > &  vars,
float  fillValue = 0.0 
)
static private

Definition at line 8 of file SimpleModel.cxx.

References check_time_usage::float, compare_h5_caf::idx, makeTrainCVSamples::int, and fillBadChanDBTables::result.

Referenced by getProngTensor().

11 {
12  tensorflow::Tensor result(
13  tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 1, int(vars.size())})
14  );
15 
16  auto resultData = result.tensor<float, 3>();
17 
18  for (size_t idx = 0; idx < vars.size(); idx++) {
19  resultData(0, 0, int(idx)) = fillValue;
20  }
21 
22  return result;
23 }
const std::map< std::pair< std::string, std::string >, Variable > vars
tensorflow::Tensor SimpleModel::getProngTensor ( const std::unordered_map< std::string, std::vector< double >> &  varMap,
const std::vector< std::string > &  vars 
)
static private

Definition at line 42 of file SimpleModel.cxx.

References check_time_usage::float, getFakeProngTensor(), compare_h5_caf::idx, makeTrainCVSamples::int, and fillBadChanDBTables::result.

Referenced by predict().

46 {
47  if (vars.empty()) {
48  return getFakeProngTensor(vars, 0.0);
49  }
50 
51  const size_t nProngs = (varMap.empty()) ? 0 : varMap.at(vars[0]).size();
52 
53  if (nProngs == 0) {
54  /*
55  * NOTE: Fake tensor with nProngs == 1 is needed, since otherwise
56  * tensorflow fails to infer graph dimensions.
57  */
58  return getFakeProngTensor(vars, 0.0);
59  }
60 
61  tensorflow::Tensor result(
62  tensorflow::DT_FLOAT,
63  tensorflow::TensorShape({1, int(nProngs), int(vars.size())})
64  );
65  auto resultData = result.tensor<float, 3>();
66 
67  for (size_t idx = 0; idx < vars.size(); idx++)
68  {
69  const auto &values = varMap.at(vars[idx]);
70 
71  if (values.size() != nProngs) {
72  throw std::runtime_error("Prongs have different lengths");
73  }
74 
75  for (size_t pngIdx = 0; pngIdx < nProngs; pngIdx++) {
76  resultData(0, pngIdx, int(idx)) = values[pngIdx];
77  }
78  }
79 
80  return result;
81 }
static tensorflow::Tensor getFakeProngTensor(const std::vector< std::string > &vars, float fillValue=0.0)
Definition: SimpleModel.cxx:8
const std::map< std::pair< std::string, std::string >, Variable > vars
tensorflow::Tensor SimpleModel::getSliceTensor ( const std::unordered_map< std::string, double > &  varMap,
const std::vector< std::string > &  vars 
)
static private

Definition at line 25 of file SimpleModel.cxx.

References check_time_usage::float, compare_h5_caf::idx, makeTrainCVSamples::int, and fillBadChanDBTables::result.

Referenced by predict().

29 {
30  tensorflow::Tensor result(
31  tensorflow::DT_FLOAT, tensorflow::TensorShape({1, int(vars.size())})
32  );
33  auto resultData = result.tensor<float, 2>();
34 
35  for (size_t idx = 0; idx < vars.size(); idx++) {
36  resultData(0, int(idx)) = varMap.at(vars[idx]);
37  }
38 
39  return result;
40 }
const std::map< std::pair< std::string, std::string >, Variable > vars
void SimpleModel::init ( ) const

Definition at line 83 of file SimpleModel.cxx.

References initialized, and tfHandler.

Referenced by demo.App::__init__(), testem0.App::__init__(), Lesson1.App::__init__(), ExN03.App::__init__(), and predict().

84 {
85  if (initialized) {
86  return;
87  }
88 
89  config.load();
90  tfHandler = std::make_shared<tensorflow::TFHandler>(config.getModelPath());
91 
92  initialized = true;
93 }
Definition: config.py:1
bool initialized
Definition: SimpleModel.h:18
std::shared_ptr< tensorflow::TFHandler > tfHandler
Definition: SimpleModel.h:17
std::vector< tensorflow::Tensor > SimpleModel::predict ( const VarDict &  varDict) const

Definition at line 105 of file SimpleModel.cxx.

References getProngTensor(), getSliceTensor(), init(), dumpEventsToText::inputs, make_pair(), VarDict::scalarVarMap, tfHandler, and VarDict::vectorVarMap.

Referenced by LSTME::Model::predict(), and SliceLID::Model::predict().

107 {
108  init();
109 
110  std::vector<std::pair<std::string, tensorflow::Tensor>> inputs;
111  inputs.reserve(config.scalarInputs.size() + config.vectorInputs.size());
112 
113  for (const auto &inputConfig : config.scalarInputs)
114  {
115  inputs.emplace_back(std::make_pair(
116  inputConfig.nodeName,
117  getSliceTensor(varDict.scalarVarMap, inputConfig.varNames)
118  ));
119  }
120 
121  for (const auto &inputConfig : config.vectorInputs)
122  {
123  inputs.emplace_back(std::make_pair(
124  inputConfig.nodeName,
125  getProngTensor(varDict.vectorVarMap, inputConfig.varNames)
126  ));
127  }
128 
129  return tfHandler->Predict(inputs, config.outputNodes);
130 }
std::unordered_map< std::string, std::vector< double > > vectorVarMap
Definition: VarDict.h:11
void init() const
Definition: SimpleModel.cxx:83
Definition: config.py:1
std::pair< Spectrum *, CheatDecomp * > make_pair(SpectrumLoaderBase &loader_data, SpectrumLoaderBase &loader_mc, HistAxis *axis, Cut *cut, const SystShifts &shift, const Var &wei)
Definition: DataMCLoad.C:336
static tensorflow::Tensor getSliceTensor(const std::unordered_map< std::string, double > &varMap, const std::vector< std::string > &vars)
Definition: SimpleModel.cxx:25
static tensorflow::Tensor getProngTensor(const std::unordered_map< std::string, std::vector< double >> &varMap, const std::vector< std::string > &vars)
Definition: SimpleModel.cxx:42
std::shared_ptr< tensorflow::TFHandler > tfHandler
Definition: SimpleModel.h:17
std::unordered_map< std::string, double > scalarVarMap
Definition: VarDict.h:10

Member Data Documentation

ModelConfig SimpleModel::config
mutable private

Definition at line 16 of file SimpleModel.h.

Referenced by Controller.Controller::make_output_directory().

bool SimpleModel::initialized
mutable private

Definition at line 18 of file SimpleModel.h.

Referenced by init().

std::shared_ptr<tensorflow::TFHandler> SimpleModel::tfHandler
mutable private

Definition at line 17 of file SimpleModel.h.

Referenced by init(), and predict().


The documentation for this class was generated from the following files: