TFHandler.cxx
Go to the documentation of this file.
1 ///////////////////////////////////////////////////////////////////////
2 // TFHandler
3 //
4 // A module to evaluate a tensorflow model in ART.
5 //
6 // \author $Author: Micah Groh
7 ////////////////////////////////////////////////////////////////////////
8 
9 #include <iostream>
10 #include <string>
12 #include "tensorflow/core/public/session.h"
13 #include "tensorflow/core/platform/env.h"
14 #include "tensorflow/core/public/session_options.h"
15 
16 namespace tensorflow
17 {
18  TFHandler::TFHandler(std::string ModelPath, int CPUlimit):
19  fSession(0),
20  fModelPath(ModelPath)
21  {
22  Initialize(CPUlimit);
23  }
24 
26  {
27  if(fSession) delete fSession;
28  }
29 
30  void TFHandler::Initialize(int CPUlimit)
31  {
32  // Initalize a session and create a frozen tensorflow graph
33  GraphDef graph_def;
34  // Force tf to only use a specified number of cores
35  tensorflow::SessionOptions options;
36  tensorflow::ConfigProto &config = options.config;
37  if (CPUlimit > 0) {
38  config.set_inter_op_parallelism_threads(CPUlimit);
39  config.set_intra_op_parallelism_threads(CPUlimit);
40  config.set_use_per_session_threads(false);
41  }
42 
43  Status status = NewSession(options, &fSession);
44  if (!status.ok()) {
45  std::cout<<"Error when making tensorflow session: "<< status.ToString() << std::endl;
46  return;
47  }
48 
49  status = ReadBinaryProto(Env::Default(),fModelPath, &graph_def);
50  if (!status.ok()){
51  std::cout<<"Error when reading model file: "<< status.ToString() << std::endl;
52  return;
53  }
54 
55  status = fSession->Create(graph_def);
56  if (!status.ok()){
57  std::cout<<"Error when creating tensorflow graph: "<< status.ToString() << std::endl;
58  return;
59  }
60 
61  std::cout<<"Successfully loaded tensorflow graph."<<std::endl;
62  }
63 
64  std::vector<Tensor> TFHandler::Predict(std::vector<std::pair<std::string,Tensor>> inputs,
65  std::vector<std::string> outputLabels)
66  {
67  std::vector<Tensor> outputs;
68 
69  // Go go go!
70  Status status = fSession->Run(inputs,outputLabels,{}, &outputs);
71 
72  if (!status.ok()){
73  std::cout<<"Error when running session: "<< status.ToString() << std::endl;
74  }
75 
76  return outputs;
77  }
78 
79 }
int status
Definition: fabricate.py:1613
TFHandler(std::string model, int CPUlimit=1)
Basic constructor, takes path to model pb.
Definition: TFHandler.cxx:18
std::string fModelPath
Definition: TFHandler.h:32
Definition: config.py:1
void Initialize(int CPUlimit)
Definition: TFHandler.cxx:30
std::vector< Tensor > Predict(std::vector< std::pair< std::string, Tensor >> inputs, std::vector< std::string > outputLabels)
Definition: TFHandler.cxx:64
OStream cout
Definition: OStream.cxx:6
::xsd::cxx::tree::string< char, simple_type > string
Definition: Database.h:154
Session * fSession
Definition: TFHandler.h:31