# produceTreeForValidation.py
1 import sys
2 import caffe
3 import matplotlib
4 import numpy as np
5 import lmdb
6 import argparse
7 import leveldb
8 import ROOT
9 import h5py
10 import eventLabelToLepton as eltl
11 from collections import defaultdict
12 
def _read_hdf5_rows(path):
    """Return every row of the first dataset in the HDF5 file at ``path``.

    The file is opened read-only and is closed again before returning, also
    when reading raises.
    """
    with h5py.File(path, "r") as h5_file:
        # keys() is a non-indexable view with h5py on Python 3, so it is
        # materialised before taking the first entry.
        first_key = list(h5_file.keys())[0]
        return list(h5_file[first_key])


if __name__ == "__main__":
    # Run a Caffe network over the images stored in a LevelDB and fill a ROOT
    # tree with, per event: the true lepton label, the number of non-zero
    # pixels, the optional vertex position read from HDF5 files, and the
    # network probabilities summed into five categories.
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--leveldb', type=str, required=True)
    parser.add_argument('--hdf5', type=str, required=False)
    parser.add_argument('--output', type=str, required=False)
    parser.add_argument('--nevents', type=int, default=100000,
                        help='maximum number of events to run over')
    args = parser.parse_args()

    count = 0
    correct = 0
    matrix = defaultdict(int)  # (real, pred) -> int
    labels_set = set()

    net = caffe.Net(args.proto, args.model, caffe.TEST)
    caffe.set_mode_gpu()
    db = leveldb.LevelDB(args.leveldb)

    if args.hdf5 is not None:
        # The HDF5 input is split into numbered files: <prefix>.0.h5,
        # <prefix>.1.h5, ...  Start with the first one.
        h_data = _read_hdf5_rows(args.hdf5 + ".0.h5")
        # h_size is the cumulative number of rows loaded so far.
        h_size = len(h_data)

    # The tree is created before the output file is opened and is written
    # explicitly at the end under the key "t".
    t = ROOT.TTree('t1', 'tree with histos')
    if args.output is None:
        fout = ROOT.TFile("output.root", "RECREATE")
    else:
        print("Output name: %s" % args.output)
        fout = ROOT.TFile("%s.root" % args.output, "RECREATE")
    fout.cd()

    # One-element arrays act as the branch buffers that t.Fill() reads.
    arrpidNuMu = np.zeros(1, dtype=float)
    arrpidNue = np.zeros(1, dtype=float)
    arrpidNuTau = np.zeros(1, dtype=float)
    arrpidNC = np.zeros(1, dtype=float)
    arrpidCosmic = np.zeros(1, dtype=float)

    # '/I' leaves are 32-bit signed integers, so the buffers must be int32;
    # numpy's default int is 64 bits wide on most platforms.
    truelabel = np.zeros(1, dtype=np.int32)
    nhit = np.zeros(1, dtype=np.int32)

    tVertX = np.zeros(1, dtype=float)
    tVertY = np.zeros(1, dtype=float)
    tVertZ = np.zeros(1, dtype=float)

    t.Branch('truelabel', truelabel, 'truelabel/I')
    t.Branch('nhit', nhit, 'nhit/I')
    t.Branch('tvertx', tVertX, 'tvertx/D')
    t.Branch('tverty', tVertY, 'tverty/D')
    t.Branch('tvertz', tVertZ, 'tvertz/D')
    t.Branch('numu', arrpidNuMu, 'numu/D')
    t.Branch('nue', arrpidNue, 'nue/D')
    t.Branch('nutau', arrpidNuTau, 'nutau/D')
    t.Branch('nc', arrpidNC, 'nc/D')
    t.Branch('cosmic', arrpidCosmic, 'cosmic/D')

    # Number of entries of the 'prob' output that are folded into the five
    # category scores below.
    n_event_labels = 391

    n_events = args.nevents
    h_count = 0        # row index inside the currently loaded HDF5 file
    file_count_h = 0   # index of the currently loaded HDF5 file

    for key, value in db.RangeIter():
        # '>=' so that exactly n_events entries are processed and the
        # progress figure stops at 100%.
        if count >= n_events:
            break

        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(value)
        label = int(datum.label)
        truelabel[0] = eltl.labelToLepton(label)

        image = caffe.io.datum_to_array(datum)
        image = image.astype(np.uint8)
        batch = np.asarray([image])

        nhit[0] = np.count_nonzero(batch)

        if args.hdf5 is not None:
            if count >= h_size:
                # Current file exhausted: move on to the next numbered file.
                file_count_h += 1
                h_data = _read_hdf5_rows(
                    args.hdf5 + "." + str(file_count_h) + ".h5")
                h_count = 0
                h_size += len(h_data)

            tVertX[0] = h_data[h_count][0]
            tVertY[0] = h_data[h_count][1]
            tVertZ[0] = h_data[h_count][2]

        out = net.forward_all(data=batch)
        probs = out['prob'][0]

        plabel = int(probs.argmax(axis=0))

        # Sum the per-label probabilities by the category index that
        # eltl.labelToLepton assigns to each label.
        lep_sum = np.zeros(5, dtype=float)
        for entry in range(n_event_labels):
            lep_sum[eltl.labelToLepton(entry)] += probs[entry]

        arrpidNuMu[0] = float(lep_sum[0])
        arrpidNue[0] = float(lep_sum[1])
        arrpidNuTau[0] = float(lep_sum[2])
        arrpidNC[0] = float(lep_sum[3])
        arrpidCosmic[0] = float(lep_sum[4])

        t.Fill()

        h_count += 1
        count += 1

        iscorrect = label == plabel
        if iscorrect:
            correct += 1
        matrix[(label, plabel)] += 1
        labels_set.update([label, plabel])

        # Uncomment to report every misclassified entry:
        # if not iscorrect:
        #     print("\rError: key=%s, expected %i but predicted %i"
        #           % (key, label, plabel))

        sys.stdout.write("\rAccuracy: %.1f%% Progress: %.2f%%"
                         % (100. * correct / count, 100. * count / n_events))
        sys.stdout.flush()

    print(str(correct) + " out of " + str(count) + " were classified correctly")

    fout.WriteTObject(t, "t")
    # Close the output file explicitly instead of relying on interpreter
    # teardown to flush it.
    fout.Close()
146 
147 
148 
149 
150 
bool print