Model.cxx
Go to the documentation of this file.
2 
#include <stdexcept>
#include <string>
#include <vector>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"
5 
6 namespace SliceLID
7 {
8 
9 const std::vector<InputConfigKeys> Model::sliceInputs({
10  { "input_slice", "vars_slice" }
11 });
12 
13 const std::vector<InputConfigKeys> Model::prongInputs({
14  { "input_png3d", "vars_png3d" }
15 });
16 const std::vector<std::string> Model::outputs({ "target" });
17 
18 
19 Model::Model(const std::string &savedir)
20  : simpleModel(savedir, sliceInputs, prongInputs, outputs)
21 { }
22 
24 {
25  std::vector<tensorflow::Tensor> outputs = simpleModel.predict(varDict);
26 
27  auto outputData = outputs[0].tensor<float, 2>();
29 
30  result.nc = outputData(0, 0);
31  result.nue = outputData(0, 1);
32  result.numu = outputData(0, 2);
33  result.nutau = outputData(0, 3);
34  result.cosmic = outputData(0, 4);
35 
36  return result;
37 }
38 
39 }
Prediction predict(const VarDict &varDict)
Definition: Model.cxx:23
Model(const std::string &savedir)
Definition: Model.cxx:19
static const std::vector< InputConfigKeys > prongInputs
Definition: Model.h:21
::xsd::cxx::tree::string< char, simple_type > string
Definition: Database.h:154
SimpleModel simpleModel
Definition: Model.h:18
static const std::vector< std::string > outputs
Definition: Model.h:22
static const std::vector< InputConfigKeys > sliceInputs
Definition: Model.h:20
Definition: VarDict.h:7
std::vector< tensorflow::Tensor > predict(const VarDict &varDict) const