produceTreeForValidationProng.py
Go to the documentation of this file.
1 import sys
2 import caffe
3 import matplotlib
4 import numpy as np
5 import lmdb
6 import argparse
7 import leveldb
8 import ROOT
9 import h5py
10 import eventLabelToLepton as eltl
11 from collections import defaultdict
12 
if __name__ == "__main__":
    # Run a trained Caffe classifier over every image stored in a LevelDB and
    # record, per image, the true label, the predicted label, the number of
    # non-zero pixels, the (optional) vertex read from HDF5 side files and the
    # per-class scores in a ROOT TTree.
    #
    # Arguments:
    #   --proto    network definition passed to caffe.Net
    #   --model    trained weights passed to caffe.Net
    #   --leveldb  path of the LevelDB holding serialized caffe Datum records
    #   --hdf5     optional prefix; files "<prefix>.<i>.h5" (i = 0, 1, ...) are
    #              read in order and consumed row by row alongside the LevelDB
    #   --output   optional output stem; "<stem>.root" is written
    #              ("output.root" when omitted)
    #   --nevents  optional cap on the number of LevelDB entries processed
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--leveldb', type=str, required=True)
    parser.add_argument('--hdf5', type=str, required=False)
    parser.add_argument('--output', type=str, required=False)
    parser.add_argument('--nevents', type=int, required=False, default=5000000)
    args = parser.parse_args()

    def load_hdf5_rows(file_index):
        """Return the rows of the first entry of "<--hdf5>.<file_index>.h5".

        The file is closed before returning, including when reading fails.
        """
        path = args.hdf5 + "." + str(file_index) + ".h5"
        with h5py.File(path, "r") as h5_file:
            # keys() is a view on Python 3 / recent h5py, so materialize it
            # before indexing.
            first_key = list(h5_file.keys())[0]
            return list(h5_file[first_key])

    count = 0
    correct = 0

    # NOTE(review): the mode is now selected before the net is built, which is
    # the order used in the pycaffe examples; confirm no setup depends on the
    # previous (net first, then GPU) order.
    caffe.set_mode_gpu()
    net = caffe.Net(args.proto, args.model, caffe.TEST)
    db = leveldb.LevelDB(args.leveldb)

    # Rows of the HDF5 chunk currently being consumed, the position inside it,
    # and the index of that chunk file.
    use_hdf5 = args.hdf5 is not None
    h_data = []
    h_count = 0
    fileCountH = 0
    if use_hdf5:
        h_data = load_hdf5_rows(fileCountH)

    # NOTE(review): the tree is created before the output file is opened and
    # is written explicitly under the key "t" at the end; presumably it stays
    # memory-resident until then -- confirm this is acceptable for large runs.
    t = ROOT.TTree('t1', 'tree with histos')
    if args.output is None:
        fout = ROOT.TFile("output.root", "RECREATE")
    else:
        print("Output name: %s" % args.output)
        fout = ROOT.TFile("%s.root" % args.output, "RECREATE")
    fout.cd()

    # One-element arrays act as the fill buffers of the tree branches.
    # '/I' leaves are 32-bit integers in ROOT, so the buffers are int32 rather
    # than NumPy's platform-dependent default integer.
    truelabel = np.zeros(1, dtype=np.int32)
    truelabelall = np.zeros(1, dtype=np.int32)
    selectedlabel = np.zeros(1, dtype=np.int32)
    nhit = np.zeros(1, dtype=np.int32)

    tVertX = np.zeros(1, dtype=np.float64)
    tVertY = np.zeros(1, dtype=np.float64)
    tVertZ = np.zeros(1, dtype=np.float64)

    t.Branch('truelabelall', truelabelall, 'truelabelall/I')
    t.Branch('truelabel', truelabel, 'truelabel/I')
    t.Branch('selectedlabel', selectedlabel, 'selectedlabel/I')
    t.Branch('nhit', nhit, 'nhit/I')
    t.Branch('tvertx', tVertX, 'tvertx/D')
    t.Branch('tverty', tVertY, 'tverty/D')
    t.Branch('tvertz', tVertZ, 'tvertz/D')

    # (branch name, index into the network's 'prob' output), listed in the
    # order in which the branches are created.
    pid_branches = (
        ('proton', 2),
        ('pion', 4),
        ('gamma', 6),
        ('muon', 1),
        ('electron', 0),
        ('pizero', 5),
        ('neutron', 3),
    )
    # The dict keeps every buffer alive for as long as the tree fills from it.
    pid_buffers = {}
    for branch_name, _ in pid_branches:
        pid_buffers[branch_name] = np.zeros(1, dtype=np.float64)
        t.Branch(branch_name, pid_buffers[branch_name], branch_name + '/D')

    # Upper bound on the number of LevelDB entries to process.
    nEvents = args.nevents

    for key, value in db.RangeIter():
        # '>=' so that exactly nEvents entries are processed at most.
        if count >= nEvents:
            break

        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(value)
        label = int(datum.label)
        truelabelall[0] = label
        truelabel[0] = eltl.labelToLepton(label)

        image = caffe.io.datum_to_array(datum).astype(np.uint8)
        batch = np.asarray([image])  # batch of one image for forward_all

        nhit[0] = np.count_nonzero(batch)

        if use_hdf5:
            # Move on to the next chunk file once the current one is used up;
            # 'while' also steps over a chunk that contains no rows.
            while h_count >= len(h_data):
                fileCountH += 1
                h_data = load_hdf5_rows(fileCountH)
                h_count = 0

            vertex = h_data[h_count]
            tVertX[0] = vertex[0]
            tVertY[0] = vertex[1]
            tVertZ[0] = vertex[2]

        out = net.forward_all(data=batch)
        probs = out['prob'][0]

        plabel = int(probs.argmax(axis=0))
        selectedlabel[0] = plabel

        for branch_name, prob_index in pid_branches:
            pid_buffers[branch_name][0] = probs[prob_index]

        t.Fill()

        h_count += 1
        count += 1

        if label == plabel:
            correct += 1
        # else:
        #     print("\rError: key=%s, expected %i but predicted %i"
        #           % (key, label, plabel))

        sys.stdout.write("\rAccuracy: %.1f%% Progress: %.2f%%" % (100.*correct/count,100.*count/nEvents))
        sys.stdout.flush()

    print(str(correct) + " out of " + str(count) + " were classified correctly")

    fout.WriteTObject(t, "t")
    # Close explicitly so the file is finalized even if interpreter teardown
    # does not do it.
    fout.Close()
151 
152 
153 
154 
155 
bool print