tensorflow::TFHandler Class Reference

Wrapper for TensorFlow which handles construction and prediction.

#include "/cvmfs/nova-development.opensciencegrid.org/novasoft/releases/N21-05-05/TensorFlowHandler/TFHandler.h"

Public Member Functions

 TFHandler (std::string model, int CPUlimit=1)
 Basic constructor; takes the path to the model .pb file.
 
 ~TFHandler ()
 
void Initialize (int CPUlimit)
 
std::vector< Tensor > Predict (std::vector< std::pair< std::string, Tensor >> inputs, std::vector< std::string > outputLabels)
 

Private Attributes

Session * fSession
 
std::string fModelPath
 

Detailed Description

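Wrapper for TensorFlow which handles construction and prediction: it creates a session, loads a frozen graph from a binary protobuf (.pb) file, and runs that graph on named input tensors.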

Definition at line 19 of file TFHandler.h.

Constructor & Destructor Documentation

tensorflow::TFHandler::TFHandler ( std::string  model,
int  CPUlimit = 1 
)

Basic constructor; takes the path to the model .pb file and the maximum number of CPU threads (CPUlimit), then calls Initialize().

Definition at line 18 of file TFHandler.cxx.

References Initialize().

tensorflow::TFHandler::TFHandler(std::string ModelPath, int CPUlimit)
  : fSession(0),
    fModelPath(ModelPath)
{
  Initialize(CPUlimit);
}
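Initialize() reads fModelPath, so the path must be stored before that call. The member-initializer list guarantees this, because members are initialized before the constructor body runs. The following is a minimal sketch of that ordering, assuming a hypothetical stand-in class HandlerSketch (not part of TFHandler) whose Initialize() records what it sees instead of creating a TensorFlow session.

```cpp
#include <cassert>  // used by the usage checks
#include <string>
#include <utility>

// Hypothetical stand-in for TFHandler: keeps only the constructor logic and
// records what Initialize() observes instead of building a session.
class HandlerSketch {
 public:
  HandlerSketch(std::string model, int CPUlimit = 1)
      : fModelPath(std::move(model)) {  // runs before the constructor body
    Initialize(CPUlimit);
  }

  const std::string& PathSeenByInit() const { return fPathSeenByInit; }
  int CPUlimitSeenByInit() const { return fCPUlimitSeenByInit; }

 private:
  void Initialize(int CPUlimit) {
    // In TFHandler, this is where fModelPath is handed to ReadBinaryProto.
    fPathSeenByInit = fModelPath;
    fCPUlimitSeenByInit = CPUlimit;
  }

  std::string fModelPath;
  std::string fPathSeenByInit;
  int fCPUlimitSeenByInit = 0;
};
```

Constructing HandlerSketch("model.pb") uses the default CPUlimit of 1, matching the TFHandler declaration; HandlerSketch("model.pb", 4) forwards 4 instead.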
tensorflow::TFHandler::~TFHandler ( )

Definition at line 25 of file TFHandler.cxx.

References fSession.

tensorflow::TFHandler::~TFHandler()
{
  if(fSession) delete fSession;
}
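Because deleting a null pointer is a no-op in C++, the if(fSession) guard is harmless but not required. A raw owning pointer also means that the implicitly generated copy constructor (none is declared in the member list above) would copy fSession, so two destructors would then delete the same session. The sketch below contrasts that raw-pointer pattern with std::unique_ptr; FakeSession, RawOwner and SmartOwner are hypothetical names, and FakeSession merely stands in for tensorflow::Session by counting live instances.

```cpp
#include <cassert>  // used by the usage checks
#include <memory>
#include <type_traits>

// Hypothetical stand-in for tensorflow::Session: counts live instances.
struct FakeSession {
  static int live;
  FakeSession() { ++live; }
  ~FakeSession() { --live; }
};
int FakeSession::live = 0;

// Same ownership as TFHandler: a raw pointer released in the destructor.
// No null check is needed, since deleting a null pointer does nothing.
struct RawOwner {
  FakeSession* fSession = nullptr;
  ~RawOwner() { delete fSession; }
};

// Alternative: unique_ptr deletes the session automatically and makes the
// owner non-copyable, which rules out a double delete.
struct SmartOwner {
  std::unique_ptr<FakeSession> fSession;
};

// RawOwner can still be copied (dangerous); SmartOwner cannot.
static_assert(std::is_copy_constructible<RawOwner>::value, "raw owner copies");
static_assert(!std::is_copy_constructible<SmartOwner>::value, "smart owner does not");
```

With RawOwner, a statement such as RawOwner copy = r; compiles and later deletes the same FakeSession twice; as far as the members listed on this page show, copying a TFHandler would behave the same way.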

Member Function Documentation

void tensorflow::TFHandler::Initialize ( int  CPUlimit)

Definition at line 30 of file TFHandler.cxx.

References std::cout, std::endl, fModelPath, and fSession.

Referenced by TFHandler().

void tensorflow::TFHandler::Initialize(int CPUlimit)
{
  // Initialize a session and load a frozen TensorFlow graph
  GraphDef graph_def;
  // Limit TensorFlow's thread pools to CPUlimit threads (CPUlimit <= 0 keeps the defaults)
  tensorflow::SessionOptions options;
  tensorflow::ConfigProto &config = options.config;
  if (CPUlimit > 0) {
    config.set_inter_op_parallelism_threads(CPUlimit);
    config.set_intra_op_parallelism_threads(CPUlimit);
    config.set_use_per_session_threads(false);
  }

  Status status = NewSession(options, &fSession);
  if (!status.ok()) {
    std::cout << "Error when making tensorflow session: " << status.ToString() << std::endl;
    return;
  }

  status = ReadBinaryProto(Env::Default(), fModelPath, &graph_def);
  if (!status.ok()) {
    std::cout << "Error when reading model file: " << status.ToString() << std::endl;
    return;
  }

  status = fSession->Create(graph_def);
  if (!status.ok()) {
    std::cout << "Error when creating tensorflow graph: " << status.ToString() << std::endl;
    return;
  }

  std::cout << "Successfully loaded tensorflow graph." << std::endl;
}
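Initialize() only overrides the thread settings when CPUlimit is positive; otherwise the ConfigProto defaults stay in place, where 0 in either parallelism field lets TensorFlow choose the thread count. The branch can be checked in isolation with a plain struct. ThreadSettings and ApplyCPUlimit are hypothetical names that mirror the three ConfigProto fields the listing sets; this is a sketch of the logic, not TensorFlow code.

```cpp
#include <cassert>  // used by the usage checks

// Hypothetical mirror of the three ConfigProto fields touched by Initialize().
// A value of 0 in a parallelism field means "let TensorFlow choose".
struct ThreadSettings {
  int inter_op_parallelism_threads = 0;
  int intra_op_parallelism_threads = 0;
  bool use_per_session_threads = false;
};

// Same branch as Initialize(): only a positive CPUlimit overrides the defaults.
ThreadSettings ApplyCPUlimit(int CPUlimit) {
  ThreadSettings s;
  if (CPUlimit > 0) {
    s.inter_op_parallelism_threads = CPUlimit;
    s.intra_op_parallelism_threads = CPUlimit;
    s.use_per_session_threads = false;  // share the process-wide thread pools
  }
  return s;
}
```

With the default CPUlimit of 1, both the inter-op pool (roughly, concurrency between independent ops) and the intra-op pool (parallelism inside a single op) are set to one thread; passing 0 or a negative value leaves both fields at TensorFlow's defaults.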
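Each of the three setup steps (NewSession, ReadBinaryProto, Session::Create) returns a Status, and Initialize() stops at the first failure after printing it. Because Initialize() returns void and only writes to std::cout, the code that constructs a TFHandler cannot detect a failed load; in the listing shown, Predict() then calls fSession->Run without checking fSession. The stop-at-first-failure pattern is sketched below with a hypothetical SketchStatus type standing in for tensorflow::Status and a hypothetical RunSteps helper that reports how many steps succeeded.

```cpp
#include <cassert>  // used by the usage checks
#include <functional>
#include <string>
#include <vector>

// Hypothetical minimal stand-in for tensorflow::Status (not the real class).
struct SketchStatus {
  bool okay;
  std::string message;
  bool ok() const { return okay; }
};

// Runs the steps in order and stops at the first failure, as Initialize()
// does with its three setup calls. Returns the number of steps that
// succeeded, so a caller can tell how far setup got.
int RunSteps(const std::vector<std::function<SketchStatus()>>& steps) {
  int done = 0;
  for (const auto& step : steps) {
    if (!step().ok()) break;
    ++done;
  }
  return done;
}
```

For example, RunSteps({good, bad, good}) returns 1: the session was created but the model file could not be read, which corresponds to a TFHandler whose fSession exists but holds no graph.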
std::vector< Tensor > tensorflow::TFHandler::Predict ( std::vector< std::pair< std::string, Tensor >>  inputs,
std::vector< std::string >  outputLabels 
)

Definition at line 64 of file TFHandler.cxx.

References std::cout, std::endl, and fSession.

Referenced by cvnneutronprongtf::CVNNeutronProngTF::CalcResult(), cvntf::CVNTF::produce(), cvntf::CVNCosmicTF::produce(), regcvntf::RegCVNTF::produce(), nuonecvntf::NuonECVNTF::produce(), cvnprongtf::CVNProngTF::produce(), cvneventtf::CVNEventTF::produce(), cvntf::CVNProngEvaluatorTF::produce(), and nerd::NERDEval::run_graph().

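std::vector<Tensor> tensorflow::TFHandler::Predict(std::vector<std::pair<std::string, Tensor>> inputs,
                                                   std::vector<std::string> outputLabels)
{
  std::vector<Tensor> outputs;

  // Feed the named inputs and fetch outputLabels; {} means no extra target nodes
  Status status = fSession->Run(inputs, outputLabels, {}, &outputs);

  if (!status.ok()) {
    std::cout << "Error when running session: " << status.ToString() << std::endl;
  }

  return outputs;
}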
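Predict() forwards to Session::Run(inputs, outputLabels, {}, &outputs): inputs pairs each graph tensor name with the Tensor to feed, outputLabels names the tensors to fetch, and the empty third argument means no additional target nodes are run. TensorFlow returns the fetched tensors in the same order as outputLabels. On failure Predict() only prints the error and returns whatever outputs holds, so callers should check that the result has one entry per label before indexing. The sketch below shows that calling convention with hypothetical stand-ins: FakeTensor is a flat float buffer, and FakeRun plays the role of the session, returning an empty vector for an unknown fetch name (a modeling choice, not TensorFlow behavior).

```cpp
#include <cassert>  // used by the usage checks
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for tensorflow::Tensor: a flat buffer of floats.
using FakeTensor = std::vector<float>;

// Hypothetical stand-in for Session::Run. It supports two fetch names,
// "sum" and "count", computed over every fed element, and returns the
// results in the order of outputLabels. Unknown names yield no outputs.
std::vector<FakeTensor> FakeRun(
    const std::vector<std::pair<std::string, FakeTensor>>& inputs,
    const std::vector<std::string>& outputLabels) {
  float sum = 0.0f;
  float count = 0.0f;
  for (const auto& feed : inputs) {
    for (float v : feed.second) {
      sum += v;
      count += 1.0f;
    }
  }
  std::vector<FakeTensor> outputs;
  for (const auto& label : outputLabels) {
    if (label == "sum") {
      outputs.push_back({sum});
    } else if (label == "count") {
      outputs.push_back({count});
    } else {
      return {};  // treat an unknown fetch as a failed run
    }
  }
  return outputs;
}
```

Fetching {"count", "sum"} after feeding {1, 2, 3} yields {3} then {6}, in the order requested; checking the size of the returned vector before indexing it is the same guard a TFHandler caller would need.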

Member Data Documentation

std::string tensorflow::TFHandler::fModelPath (private)

Definition at line 32 of file TFHandler.h.

Referenced by Initialize().

Session* tensorflow::TFHandler::fSession (private)

Definition at line 31 of file TFHandler.h.

Referenced by Initialize(), Predict(), and ~TFHandler().


The documentation for this class was generated from the following files: