CVNEventTF_module.cc
Go to the documentation of this file.
1 #include<iostream>
2 
13 #include "fhiclcpp/ParameterSet.h"
15 
16 
17 #include "RecoBase/Cluster.h"
18 #include "RecoBase/Prong.h"
19 #include "RecoBase/PID.h"
20 #include "SummaryData/SpillData.h"
21 #include "CVN/func/PixelMap.h"
22 #include "CVN/func/Result.h"
24 #include "CVN/func/AssignLabels.h"
25 #include "CVN/func/TrainingData.h"
28 #include "CVN/func/ProngType.h"
30 #include "Utilities/AssociationUtil.h"
32 
34 
35 #include <memory>
36 
37 namespace cvneventtf{
// art producer that runs the CVN TensorFlow network over an event's slices,
// putting a std::vector<cvn::Result> and an Assns<cvn::Result, rb::Cluster>
// into the event.
// NOTE(review): several member declarations (input labels, model names, the
// fTFFHC/fTFRHC model pointers, GetModel) are absent from this listing.
38  class CVNEventTF : public art::EDProducer{
39  public:
// Copies every configuration value out of the FHiCL parameter set.
40  explicit CVNEventTF(fhicl::ParameterSet const &pset);
// Deletes the TFHandler models held in fTFFHC / fTFRHC.
// NOTE(review): the class deletes raw pointers in its destructor but does not
// delete its copy constructor/assignment, so a copy would double-delete.
// art does not copy modules in normal use, but std::unique_ptr members would
// make this safe by construction.
41  virtual ~CVNEventTF();
42 
// Classifies each slice and puts the results and slice associations into evt.
43  void produce(art::Event& evt);
// Returns the spill's isRHC flag (antineutrino mode); aborts the process if
// the spill data cannot be read.
44  bool IsRHC(const art::Event &evt);
// Copies a flattened pixel map into a DT_FLOAT tensor of shape {1, size}.
// NOTE(review): the vector is taken by value, so callers copy it unless they
// move it in; a const& parameter (changed together with the definition)
// would avoid that.
45  tensorflow::Tensor vector_to_tensor(std::vector<unsigned char>);
47 
48  protected:
// Forwarded to PixelMap::PixelMapToVector() when flattening a pixel map.
53  bool fUseGeV;
// Number of network scores copied into each cvn::Result.
57  unsigned int fNOutput;
// Read from the "CPUlimit" parameter; where it is used is not visible in this
// listing (presumably a TensorFlow thread/CPU limit -- TODO confirm).
60  unsigned int fCPUlimit;
63  };
64 }
65 
66 namespace cvneventtf{
// Constructor: every setting is read from the module's FHiCL parameter set;
// a missing key makes pset.get throw.
// NOTE(review): the constructor's signature line is absent from this listing.
68  fSliceLabel (pset.get<std::string>("SliceLabel")),
69  fPixelMapInput(pset.get<std::string>("PixelMapInput")),
70  fGeneratorLabel (pset.get<std::string>("GeneratorLabel")),
71  fNuMILabel (pset.get<std::string>("NuMILabel")),
72  fUseGeV (pset.get<bool> ("UseGeV")),
73  fLibPath (pset.get<std::string>("LibPath")),
74  fModelFHCName (pset.get<std::string>("ModelFHCName")),
75  fModelRHCName (pset.get<std::string>("ModelRHCName")),
76  fNOutput (pset.get<unsigned int>("NOutput")),
77  fInputName (pset.get<std::string>("InputName")),
78  fOutputName (pset.get<std::string>("OutputName")),
79  fCPUlimit (pset.get<unsigned int>("CPUlimit")),
// Both models start out null; GetModel() tests them for null before use.
// NOTE(review): nullptr would state the intent more clearly than 0.
80  fTFFHC(0),
81  fTFRHC(0)
82  {
// Declare the two products that produce() puts into every event.
84  produces< std::vector<cvn::Result> >();
85  produces< art::Assns<cvn::Result,rb::Cluster> >();
86  }
87 
// Destructor body: release the FHC/RHC models held by this module.
// NOTE(review): the destructor's signature line is absent from this listing.
89  {
// NOTE(review): delete on a null pointer is a no-op, so the if-guards are
// redundant; std::unique_ptr members would make this destructor unnecessary.
90  if(fTFFHC) delete fTFFHC;
91  if(fTFRHC) delete fTFRHC;
92  }
93 
// IsRHC body: return the horn-current mode (isRHC, antineutrino mode) of the
// event's spill.
// NOTE(review): the signature and the spillPot handle declaration are absent
// from this listing.
95  {
// Simulated events read the spill record from the generator label; real data
// reads it from the NuMI label.
97  if (!evt.isRealData())
98  evt.getByLabel(fGeneratorLabel, spillPot);
99  else
100  evt.getByLabel(fNuMILabel, spillPot);
101 
102  if (spillPot.failedToGet())
103  {
// NOTE(review): abort() ends the process immediately, bypassing art's
// exception handling and end-of-job cleanup, and the error message may not be
// delivered before the process dies. Throwing an art/cet exception is the
// usual way to fail a module -- TODO confirm intended behavior.
104  mf::LogError("CVNEventTF") <<
105  "Spill Data not found, aborting without horn current information";
106  abort();
107  }
108 
109  return spillPot->isRHC;
110  }
111 
// GetModel body: return the RHC network when IsRHC(evt) is true, otherwise the
// FHC network; each pointer is checked for null before being returned.
// NOTE(review): the signature and the statements executed when a pointer is
// null (presumably constructing the TFHandler from fLibPath and the model
// name) are absent from this listing. As listed, each `if` would guard the
// `return` below it instead -- restore those lines before compiling.
113  {
114 
115  if (IsRHC(evt)) {
116  if (!fTFRHC)
118  return fTFRHC;
119  }
120  else {
121  if (!fTFFHC)
123  return fTFFHC;
124  }
125  }
126 
127  tensorflow::Tensor CVNEventTF::vector_to_tensor(std::vector<unsigned char> pm)
128  {
129  const unsigned int vectorSize = pm.size();
130 
131  // Initialize the tensors
132  tensorflow::Tensor tensor(tensorflow::DT_FLOAT, {1, vectorSize});
133  auto rel = tensor.tensor<float,2>();
134 
135  // Loop over each element
136  for(unsigned int i = 0; i < vectorSize; ++i) rel(0, i) = pm[i];
137 
138  return tensor;
139  }
140 
142  {
143 
144  tensorflow::TFHandler* fTF = GetModel(evt);
145  //Containers for things we're gonna produce
146  std::unique_ptr< std::vector<cvn::Result> >
147  resultCol(new std::vector<cvn::Result>);
148  std::unique_ptr< art::Assns<cvn::Result, rb::Cluster> >
149  assocresult(new art::Assns<cvn::Result, rb::Cluster>);
150 
151  // Get slices
153  evt.getByLabel(fSliceLabel, slicecol);
154  art::PtrVector<rb::Cluster> slicelist;
155  for(unsigned int i = 0; i < slicecol->size(); ++i){
156  slicelist.push_back(art::Ptr<rb::Cluster>(slicecol, i));
157  }
158 
159  // Get pixel maps
160  art::FindManyP<cvn::PixelMap> fmPixelMap(slicecol, evt, fPixelMapInput);
161 
162  //loop over slices
163  for(size_t iClust = 0; iClust < slicelist.size(); ++iClust) {
164  if(!fmPixelMap.isValid()) continue;
165  if(slicelist[iClust]->IsNoise()) continue;
166 
167  const std::vector<art::Ptr<cvn::PixelMap> > pixelMaps = fmPixelMap.at(iClust);
168 
169  if(pixelMaps.empty()) continue;
170 
171  std::vector<unsigned char> pmslice = (*pixelMaps[0]).PixelMapToVector(fUseGeV);
172 
173  tensorflow::Tensor tensor = vector_to_tensor(pmslice);
174 
175  std::vector<tensorflow::Tensor> result = fTF->Predict({{fInputName,tensor}},
176  {fOutputName});
177  auto tfoutput = result[0].tensor<float,2>();
178 
179  float resultvec[fNOutput];
180 
181  for(unsigned int i = 0; i<fNOutput; i++){
182  resultvec[i] = (float)tfoutput(0,i);
183  }
184  const float* output = resultvec;
185 
186  resultCol->emplace_back(output, fNOutput);
187 
188  util::CreateAssn(*this, evt, *(resultCol.get()),
189  slicelist[iClust], *(assocresult.get()), UINT_MAX);
190 
191  } // slices
192 
193  evt.put(std::move(resultCol));
194  evt.put(std::move(assocresult));
195 
196  } // produce
197 }
198 
bool isRHC
is the beam in antineutrino mode, aka RHC
Definition: SpillData.h:28
ofstream output
bool IsRHC(const art::Event &evt)
static bool CreateAssn(art::EDProducer const &prod, art::Event &evt, std::vector< T > &a, art::Ptr< U > b, art::Assns< T, U > &assn, size_t indx=UINT_MAX, std::string const &instance=std::string())
Create a 1 to 1 association between a new product and one already in the event.
pdg code and pid value
std::string EnvExpansion(const std::string &inString)
Function to expand environment variables.
Definition: EnvExpand.cxx:8
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
bool isRealData() const
Definition: Event.h:83
DEFINE_ART_MODULE(TestTMapFile)
void produce(art::Event &evt)
PixelMap for CVN.
std::vector< Tensor > Predict(std::vector< std::pair< std::string, Tensor >> inputs, std::vector< std::string > outputLabels)
Definition: TFHandler.cxx:64
tensorflow::Tensor vector_to_tensor(std::vector< unsigned char >)
tensorflow::TFHandler * fTFFHC
ProductID put(std::unique_ptr< PROD > &&product)
Definition: Event.h:102
Result for CVN.
void push_back(Ptr< U > const &p)
Definition: PtrVector.h:441
string rel
Definition: shutoffs.py:11
size_type size() const
Definition: PtrVector.h:308
tensorflow::TFHandler * GetModel(const art::Event &evt)
::xsd::cxx::tree::string< char, simple_type > string
Definition: Database.h:154
CVNEventTF(fhicl::ParameterSet const &pset)
bool getByLabel(std::string const &label, std::string const &productInstanceName, Handle< PROD > &result) const
Definition: DataViewImpl.h:344
tensorflow::TFHandler * fTFRHC
Wrapper for Tensorflow which handles construction and prediction.
Definition: TFHandler.h:19
bool failedToGet() const
Definition: Handle.h:196