CVNTF_module.cc
Go to the documentation of this file.
1 #include<iostream>
2 
10 #include "fhiclcpp/ParameterSet.h"
12 
13 #include "RecoBase/Cluster.h"
15 #include "CVN/func/PixelMap.h"
16 
18 
19 #include <memory>
20 
21 namespace cvntf{
// CVNTF: art producer that evaluates a TensorFlow network (through the
// tensorflow::TFHandler pointed to by fTF) on the first cvn::PixelMap
// associated with each non-noise slice, printing the first five network
// outputs to std::cout.
22  class CVNTF : public art::EDProducer{
23  public:
  // Configured from the FHiCL parameters SliceLabel, PixelMapInput, UseGeV
  // and ModelPath.
24  explicit CVNTF(fhicl::ParameterSet const &pset);
  // Deletes the TFHandler held in fTF, if any.
25  virtual ~CVNTF();
26 
  // Per-event entry point.
  // NOTE(review): not marked `override`; also, no evt.put appears in the
  // visible lines of its definition, so nothing is added to the event.
27  void produce(art::Event& evt);
  // Copies a flattened pixel map into a DT_FLOAT tensor of shape {1, N}.
  // NOTE(review): the vector is taken by value (one copy per call); a const
  // reference would avoid that if the definition is changed to match.
28  tensorflow::Tensor vector_to_tensor(std::vector<unsigned char>);
29 
30  protected:
  // NOTE(review): the declarations of fSliceLabel, fPixelMapInput,
  // fModelPath (std::string) and fTF (tensorflow::TFHandler*) sit on lines
  // that are missing from this listing.
  // Forwarded to cvn::PixelMap::PixelMapToVector in produce().
33  bool fUseGeV;
35 
37  };
38 }
39 
40 namespace cvntf{
  // Constructor member initializers: read the FHiCL parameters SliceLabel,
  // PixelMapInput, UseGeV and ModelPath; fTF starts out null.
42  fSliceLabel (pset.get<std::string>("SliceLabel")),
43  fPixelMapInput(pset.get<std::string>("PixelMapInput")),
44  fUseGeV (pset.get<bool> ("UseGeV")),
45  fModelPath (pset.get<std::string>("ModelPath")),
46  fTF(0)
47  {
  // NOTE(review): the constructor body (line 48) is missing from this
  // listing; presumably it creates the TFHandler from fModelPath. Confirm,
  // because produce() dereferences fTF without a null check.
  // nullptr would also be clearer than 0 as fTF's initializer.
49  }
50 
52  {
  // Destructor body: frees the TFHandler owned through the raw pointer fTF.
  // NOTE(review): the null check is redundant (deleting a null pointer is a
  // no-op); a std::unique_ptr member would make ownership explicit and
  // remove the need for a user-written destructor.
53  if(fTF) delete fTF;
54  }
55 
56  tensorflow::Tensor CVNTF::vector_to_tensor(std::vector<unsigned char> pm)
57  {
58  const unsigned int vectorSize = pm.size();
59 
60  // Initialize the tensors
61  tensorflow::Tensor tensor(tensorflow::DT_FLOAT, {1, vectorSize});
62  auto rel = tensor.tensor<float,2>();
63 
64  // Loop over each element
65  for(unsigned int i = 0; i < vectorSize; ++i) rel(0, i) = pm[i];
66 
67  return tensor;
68  }
69 
71  {
  // produce(): for each non-noise slice with at least one associated
  // cvn::PixelMap, convert the first map to a {1, N} float tensor, run it
  // through fTF->Predict (input "input", output "output_out") and print the
  // first five values of the first output tensor to std::cout.
  // NOTE(review): no evt.put appears in the visible lines, so this producer
  // currently adds nothing to the event.
72  // Get slices
  // NOTE(review): the declarations on lines 73 and 75 are missing from this
  // listing. From usage, slicecol is a handle to a collection of rb::Cluster
  // and slicelist holds art::Ptr<rb::Cluster> (probably an art::PtrVector).
  // Confirm both.
74  evt.getByLabel(fSliceLabel, slicecol);
76  for(unsigned int i = 0; i < slicecol->size(); ++i){
77  slicelist.push_back(art::Ptr<rb::Cluster>(slicecol, i));
78  }
79 
80  // Get pixel maps
  // Pixel maps associated with each slice, looked up under fPixelMapInput;
  // entry iClust corresponds to slice iClust of slicecol.
81  art::FindManyP<cvn::PixelMap> fmPixelMap(slicecol, evt, fPixelMapInput);
82 
83  //loop over slices
84  for(size_t iClust = 0; iClust < slicelist.size(); ++iClust) {
  // NOTE(review): fmPixelMap.isValid() does not depend on iClust; it could
  // be checked once before the loop.
85  if(!fmPixelMap.isValid()) continue;
86  if(slicelist[iClust]->IsNoise()) continue;
87 
  // NOTE(review): this copies the vector of Ptrs. If at() returns a
  // reference (likely for art::FindManyP), binding a const& avoids the copy.
88  const std::vector<art::Ptr<cvn::PixelMap> > pixelMaps = fmPixelMap.at(iClust);
89  if(pixelMaps.empty()) continue;
90 
91  // Fill the pixel map array for this slice
  // Only the first associated pixel map is used; fUseGeV is forwarded to
  // PixelMapToVector.
92  std::vector<unsigned char> pm = (*pixelMaps[0]).PixelMapToVector(fUseGeV);
93 
94  // Convert to format expected by the network
95  tensorflow::Tensor tensor = vector_to_tensor(pm);
96 
  // NOTE(review): fTF is dereferenced without a null check. It is
  // initialised to 0 in the constructor's member-initializer list.
97  std::vector<tensorflow::Tensor> result = fTF->Predict({{"input",tensor}},
98  {"output_out"});
99 
  // Assumes result is non-empty and result[0] is a rank-2 float tensor with
  // at least five columns. None of this is checked here.
100  auto tfoutput = result[0].tensor<float,2>();
101 
  // NOTE(review): std::endl flushes on every line; '\n' would suffice.
102  std::cout<<"Graph Output: "<<std::endl;
103  std::cout<<tfoutput(0,0)<<std::endl;
104  std::cout<<tfoutput(0,1)<<std::endl;
105  std::cout<<tfoutput(0,2)<<std::endl;
106  std::cout<<tfoutput(0,3)<<std::endl;
107  std::cout<<tfoutput(0,4)<<std::endl;
108  } // slices
109  } // produce
110 }
111 
std::string fSliceLabel
Definition: CVNTF_module.cc:31
std::string fModelPath
Definition: CVNTF_module.cc:34
tensorflow::Tensor vector_to_tensor(std::vector< unsigned char >)
Definition: CVNTF_module.cc:56
DEFINE_ART_MODULE(TestTMapFile)
PixelMap for CVN.
std::vector< Tensor > Predict(std::vector< std::pair< std::string, Tensor >> inputs, std::vector< std::string > outputLabels)
Definition: TFHandler.cxx:64
virtual ~CVNTF()
Definition: CVNTF_module.cc:51
void push_back(Ptr< U > const &p)
Definition: PtrVector.h:441
void produce(art::Event &evt)
Definition: CVNTF_module.cc:70
int evt
tensorflow::TFHandler * fTF
Definition: CVNTF_module.cc:36
string rel
Definition: shutoffs.py:11
size_type size() const
Definition: PtrVector.h:308
OStream cout
Definition: OStream.cxx:6
::xsd::cxx::tree::string< char, simple_type > string
Definition: Database.h:154
std::string fPixelMapInput
Definition: CVNTF_module.cc:32
CVNTF(fhicl::ParameterSet const &pset)
Definition: CVNTF_module.cc:41
bool getByLabel(std::string const &label, std::string const &productInstanceName, Handle< PROD > &result) const
Definition: DataViewImpl.h:344
Wrapper for Tensorflow which handles construction and prediction.
Definition: TFHandler.h:19