DecisionTree.cxx
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 /// \file DecisionTree.cxx
3 /// \brief Decision Tree PID
4 /// \author Christopher Backhouse - bckhouse@caltech.edu
5 ////////////////////////////////////////////////////////////////////////
6 
8 
9 #include "Utilities/func/MathUtil.h"
10 
11 #include "TFile.h"
12 #include "TRandom3.h"
13 #include "TTree.h"
14 #include "TVectorD.h"
15 
16 #include <algorithm>
17 #include <cassert>
18 #include <cmath>
19 #include <map>
20 #include <iostream>
21 #include <unistd.h>
22 #include <pthread.h>
23 #include <random>
24 
25 namespace lem
26 {
27  namespace dec
28  {
  // For serializing to/from a TTree. Anonymous namespace, no-one else needs
  // to see these
  namespace
  {
    int gThis;         // row index of the node currently being read/written
    int gLeft, gRight; // row indices of the children; -1 marks a Leaf row
    int gCutDim;       // variable the cut is on (Cut rows)
    double gCutVal;    // cut position (Cut rows)
    double gSig, gBkg; // training signal/background weights (Leaf rows)
    // Maps memory addresses to rows in the TTree, allowing serialization of pointers
    std::map<const TreeNode*, int> gPtrs;
  }
41 
  //......................................................................
  {
    // The Forest owns its constituent trees; free them all on destruction
    for(Tree* t: fTrees) delete t;
  }
47 
48  //......................................................................
49  double Forest::Classify(const Evt& evt) const
50  {
51  // Average all the constituent trees
52  double mean = 0;
53  for(Tree* t: fTrees) mean += t->Classify(evt);
54  mean /= fTrees.size();
55  return mean;
56  }
57 
  //......................................................................
  {
    // Deserialize a Forest. Each TTree row is one node; ToTree() writes
    // children before their parent, so a single forward pass can rebuild
    // the pointer structure.
    TFile f(fname.c_str());
    assert(!f.IsZombie());
    TTree* tr = (TTree*)f.Get("tree");
    assert(tr);

    // Hook every branch up to the file-scope scratch variables
    tr->SetBranchAddress("this", &gThis);
    tr->SetBranchAddress("left", &gLeft);
    tr->SetBranchAddress("right", &gRight);
    tr->SetBranchAddress("cutdim", &gCutDim);
    tr->SetBranchAddress("cutval", &gCutVal);
    tr->SetBranchAddress("sig", &gSig);
    tr->SetBranchAddress("bkg", &gBkg);

    // Map from TTree rows to pointers, use to connect up nodes
    std::map<int, TreeNode*> ptrs;

    // For every row in the tree
    const int N = tr->GetEntries();
    for(int n = 0; n < N; ++n){
      tr->GetEntry(n);

      if(gLeft == -1){
        assert(gRight == -1);

        // If it has no children it's a simple Leaf
        ptrs[gThis] = new Leaf(gSig, gBkg);
      }
      else{
        assert(gRight != -1);

        // If it does have children, it's a Cut. We should have seen its
        // children already, so look them up in the map.

        TreeNode* left = ptrs[gLeft];
        TreeNode* right = ptrs[gRight];
        assert(left && right);

        // Only retain unparented nodes in the map
        ptrs.erase(ptrs.find(gLeft));
        ptrs.erase(ptrs.find(gRight));

        ptrs[gThis] = new Cut(gCutDim, gCutVal, left, right);
      }
    }

    // Anything that never found a parent is a tree root
    std::cout << "Loaded " << ptrs.size() << " trees" << std::endl;
    Forest* ret = new Forest;

    int matIdx = 0;
    // Fallback in case file is old
    TMatrixD unit(kNumPIDVars+1, kNumPIDVars+1);
    unit.UnitMatrix();

    // Pair each root with its basis-transformation matrix, stored in the
    // file as "mat_0", "mat_1", ... Older files have no matrices; those
    // trees get the identity instead.
    for(auto it: ptrs){
      TMatrixD* mat = (TMatrixD*)f.Get(TString::Format("mat_%d", matIdx++));
      ret->fTrees.push_back(new Tree(it.second, mat ? *mat : unit));
    }

    return ret;
  }
122 
  //......................................................................
  {
    // Serialize the forest: one TTree row per node, plus one TMatrixD per
    // tree holding its basis transformation.
    TFile fout(fname.c_str(), "RECREATE");
    TTree tr("tree", "tree");

    // The ToTree functions of the individual nodes require us to have done this setup
    tr.Branch("this", &gThis);
    tr.Branch("left", &gLeft);
    tr.Branch("right", &gRight);
    tr.Branch("cutdim", &gCutDim);
    tr.Branch("cutval", &gCutVal);
    tr.Branch("sig", &gSig);
    tr.Branch("bkg", &gBkg);

    // And they maintain the state of gPtrs
    gPtrs.clear();

    for(unsigned int i = 0; i < fTrees.size(); ++i){
      // Writes this tree's nodes as rows (children before parents)
      fTrees[i]->ToTree(&tr);
      fTrees[i]->GetMatrix().Write(TString::Format("mat_%u", i));
    }

    tr.Write("tree");
  }
148 
  //......................................................................
  {
    // Build a random rotation of the PID-variable subspace. The extra
    // (homogeneous) row/column is left as identity so translations still
    // pass through unchanged.
    // Start with the identity matrix
    ret.UnitMatrix();

    TRandom3 r(0); // auto-seed

    // Random rotation by all the Euler angles: compose a rotation through
    // a random angle in every (i,j) coordinate plane.
    // NOTE(review): this is not a Haar-uniform random rotation, but any
    // rotation suffices to break up axis-aligned artifacts — confirm
    // uniformity is not required.
    for(int i = 0; i < kNumPIDVars-1; ++i){
      for(int j = i+1; j < kNumPIDVars; ++j){
        TMatrixD mul(kNumPIDVars+1, kNumPIDVars+1);
        mul.UnitMatrix();

        double ang = r.Uniform(2*M_PI);
        mul(i, i) = mul(j, j) = cos(ang);
        mul(i, j) = +sin(ang);
        mul(j, i) = -sin(ang);

        ret = mul*ret;
      }
    }

    return ret;
  }
175 
  //......................................................................
  {
    // Apply the homogeneous-coordinates matrix to the event's PID
    // variables. The trailing component is pinned to 1 so the matrix can
    // encode translations as well as rotations/scalings.
    for(int i = 0; i < kNumPIDVars; ++i) vec[i] = evt.vars[i];
    vec[kNumPIDVars] = 1;
    vec = mat*vec;

    // Copy the transformed coordinates into a copy of the event; all
    // other fields (weight, isSig, ...) carry over unchanged.
    Evt ret = evt;
    for(int i = 0; i < kNumPIDVars; ++i) ret.vars[i] = vec[i];
    return ret;
  }
188 
  // Boilerplate required to call thread function with the correct arguments
  struct ThreadArgs
  {
    const std::vector<Evt>* evts; // shared training sample (not owned)
    // NOTE(review): thread_func also reads `scale` and `trans` members of
    // this struct — confirm their declaration (not visible here).
  };
195 
196  void* thread_func(void* a)
197  {
198  ThreadArgs* args = (ThreadArgs*)a;
199  return Forest::TrainSingle(*args->evts, args->scale, args->trans);
200  }
201 
  //......................................................................
  Forest Forest::Train(std::vector<Evt>& trainEvts, unsigned int nTrees,
                       bool parallel)
  {
    // Train nTrees trees, each on a random half of trainEvts, optionally
    // spread across pthreads. Event weights are temporarily doubled (and
    // restored before returning) so each half-sample carries the full
    // sample's total weight.
    Forest ret;

    const unsigned int N = trainEvts.size();

    // Because we only take half the sample each time below
    for(unsigned int n = 0; n < N; ++n) trainEvts[n].weight *= 2;

    // Accumulate the mean and RMS for each variable
    double mean[kNumPIDVars] = {0,};
    double rms[kNumPIDVars] = {0,};
    double W = 0;
    for(unsigned int n = 0; n < N; ++n){
      const Evt& e = trainEvts[n];
      for(int i = 0; i < kNumPIDVars; ++i){
        if(fabs(e.vars[i]) < 10){ // There are some crazy values in there
          mean[i] += e.weight*e.vars[i];
          rms[i] += e.weight*util::sqr(e.vars[i]);
        }
      }
      // NOTE(review): W counts every event, while outliers (|x| >= 10) are
      // excluded from the sums above, biasing mean/rms slightly low —
      // confirm this is intended.
      W += e.weight;
    }

    // Translation matrix. To learn how this works, read
    // http://en.wikipedia.org/wiki/Translation_matrix
    TMatrixD transMat(kNumPIDVars+1, kNumPIDVars+1);
    transMat.UnitMatrix();

    // Matrix to normalize coordinates into sigmas
    TMatrixD scaleMat(kNumPIDVars+1, kNumPIDVars+1);
    scaleMat(kNumPIDVars, kNumPIDVars) = 1;

    for(int i = 0; i < kNumPIDVars; ++i){
      mean[i] /= W;
      // After this line rms[i] actually holds the variance E[x^2]-mean^2
      rms[i] = rms[i]/W-util::sqr(mean[i]);

      transMat(i, kNumPIDVars) = -mean[i]; // shift each variable to mean zero
      scaleMat(i, i) = 1/sqrt(rms[i]);     // and scale to unit sigma
    }


    if(parallel){
      // Somewhat oversubscribe the number of cores, to keep them all busy
      const int nThreads = sysconf(_SC_NPROCESSORS_ONLN)*1.5;

      std::vector<pthread_t*> ths;
      int live = 0; // counter how many threads are currently running

      // All workers share one argument struct; they only read from it
      ThreadArgs args = {&trainEvts, scaleMat, transMat};

      while(true){
        // Start new threads if needed
        while(live < nThreads && ret.fTrees.size()+live < nTrees){
          ths.push_back(new pthread_t);
          pthread_create(ths.back(), 0, &thread_func, &args);
          ++live;
        }

        // Break when we're done
        if(live == 0 && ret.fTrees.size() >= nTrees) break;

        // Have to busy-wait because sleep() suspends all the threads

        // Check all the threads to see if they've finished
        for(unsigned int i = 0; i < ths.size(); ++i){
          void* tr;
#ifndef DARWINBUILD
          if(ths[i] && pthread_tryjoin_np(*ths[i], &tr) == 0){
#else
          // NOTE(review): this cerr line runs on every poll iteration, and
          // pthread_join blocks, so the OSX fallback effectively joins the
          // workers one at a time — confirm this is acceptable.
          std::cerr << "Building on OSX requires using pthread_join rather than pthread_join_np" << std::endl;
          if(ths[i] && pthread_join(*ths[i], &tr) == 0){
#endif
            // If so, clean them up and store their results
            delete ths[i];
            ths[i] = 0;
            ret.fTrees.push_back((Tree*)tr);
            --live;
            std::cout << "." << std::flush;
          }
        } // end for i
      } // end while
      std::cout << std::endl;
    }
    else{
      // Serial path: train the trees one after another
      for(unsigned int n = 0; n < nTrees; ++n){
        ret.fTrees.push_back(TrainSingle(trainEvts, scaleMat, transMat));
        std::cout << "." << std::flush;
      }
      std::cout << std::endl;
    }

    // Try to put things back as we found them
    for(unsigned int n = 0; n < N; ++n) trainEvts[n].weight /= 2;

    return ret;
  }
301 
  //......................................................................
  Tree* Forest::TrainSingle(const std::vector<Evt>& trainEvts,
                            const TMatrixD& scaleMat,
                            const TMatrixD& transMat)
  {
    // To avoid linear artifacts in the output, we rotate into a random
    // basis. Note that this is different for each tree in the forest.
    const TMatrixD rotMat = RandomOrthoMatrix();

    // Before rotating, center and normalize the variables so they're all
    // treated on an equal footing.
    const TMatrixD mat = rotMat * scaleMat * transMat;

    const unsigned int N = trainEvts.size();

    // Work on a transformed copy of every event; the originals are untouched
    std::vector<Evt> transformedEvts;
    for(unsigned int i = 0; i < N; ++i){
      transformedEvts.push_back(TransformEvent(trainEvts[i], mat));
    }

    // Shuffle the training events and take the first half. This ensures
    // each tree gets a different sample, so on average they don't
    // overtrain.
    std::random_device rng;
    std::mt19937 urng(rng());
    std::shuffle(transformedEvts.begin(), transformedEvts.end(), urng);

    // The training algorithm requires multiple lists of the events, sorted
    // according to each PID variable. The lists hold pointers into
    // transformedEvts, which must therefore outlive them.
    std::vector<std::list<Evt*>> sorted(kNumPIDVars);
    for(int v = 0; v < kNumPIDVars; ++v){
      for(unsigned int i = 0; i < N/2; ++i){
        sorted[v].push_back(&transformedEvts[i]);
      }

      sorted[v].sort([v](const Evt* a, const Evt* b)
                     {
                       return a->vars[v] < b->vars[v];
                     });
    }

    // Train on the first half, then prune with the second, statistically
    // independent, half to suppress overtraining.
    TreeNode* tree = TrainSingleTransformed(sorted, 0);
    std::vector<Evt> test(transformedEvts.begin()+N/2, transformedEvts.end());
    tree = tree->Prune(test);
    return new Tree(tree, mat);
  }
349 
  // Tuning knobs for tree growth
  /// Don't build a tree deeper than this
  const unsigned int kDepthLimit = 20;
  /// Once there are less than this many events in a subtree, stop subdividing
  const unsigned int kMinBucketSize = 20;
354 
  //......................................................................
  TrainSingleTransformed(std::vector<std::list<Evt*> >& sorted,
                         unsigned int depth)
  {
    // Recursively grow one (sub)tree. `sorted` holds the surviving events
    // once per PID variable, each list pre-sorted by that variable; the
    // lists are consumed (cleared) on the way down to save memory.
    const unsigned int N = sorted[0].size();

    // Total signal/background weight reaching this node
    double totSig = 0;
    double totBkg = 0;

    for(Evt* evt : sorted[0]){
      const double w = evt->weight;
      if(evt->isSig) totSig += w; else totBkg += w;
    }

    // Variety of reasons not to try to cut anymore
    if(depth == kDepthLimit || N < kMinBucketSize || totSig == 0 || totBkg == 0)
      return new Leaf(totSig, totBkg);

    double bestFOMSq = 0;
    // Convention that the invalid iterator is taken from the first list
    std::list<Evt*>::iterator bestCutIt = sorted[0].end();
    int bestCutDim = 0;

    // Try a cut in every dimension
    for(int dim = 0; dim < kNumPIDVars; ++dim){
      // Counts below the cut
      double s1 = 0;
      double b1 = 0;
      for(std::list<Evt*>::iterator it = sorted[dim].begin(); it != sorted[dim].end(); ++it){
        // Keep the counts below and above up to date
        const double w = (*it)->weight;
        if((*it)->isSig) s1 += w; else b1 += w;

        const double s2 = totSig-s1;
        const double b2 = totBkg-b1;

        if(s1+b1 == 0 || s2+b2 == 0) continue;

        // Figure of merit if we were to cut here: s^2/(s+b) summed over
        // both sides, i.e. the squared quadrature sum of s/sqrt(s+b)
        const double fomSq = util::sqr(s1)/(s1+b1)+util::sqr(s2)/(s2+b2);

        if(fomSq > bestFOMSq){
          // NOTE(review): when the best split lands on the very first
          // element, bestCutIt is left untouched (possibly still end())
          // while bestFOMSq is raised, which can suppress valid cuts found
          // later in other dimensions — confirm this is intended.
          bestFOMSq = fomSq;
          if(it != sorted[dim].begin()){
            // Definition of the cut is the last iterator in the "left" sample
            bestCutIt = it;
            --bestCutIt;
          }
          bestCutDim = dim;
        }
      } // end for n
    } // end for dim

    // Didn't find a cut, or it's no better than not cutting at all
    if(bestCutIt == sorted[0].end() ||
       bestFOMSq <= util::sqr(totSig)/(totSig+totBkg))
      return new Leaf(totSig, totBkg);

    // First iterator in the "right" sample
    std::list<Evt*>::iterator next = bestCutIt;
    ++next;

    // Cut is halfway between the two points we want to separate
    const double cutPos = ((*bestCutIt)->vars[bestCutDim]+(*next)->vars[bestCutDim])/2;

    // Children require the same sorted lists that we got
    std::vector<std::list<Evt*> > sortedLeft(kNumPIDVars);
    std::vector<std::list<Evt*> > sortedRight(kNumPIDVars);

    for(int v = 0; v < kNumPIDVars; ++v){
      // Partitioning a sorted list preserves the ordering on both sides
      for(Evt* evt : sorted[v]){
        if(evt->vars[bestCutDim] < cutPos)
          sortedLeft[v].push_back(evt);
        else
          sortedRight[v].push_back(evt);
      }

      // Make sure to clear lists out once we don't need them, save memory
      // while recursing.
      sorted[v].clear();
    }
    sorted.clear();

    TreeNode* left = TrainSingleTransformed(sortedLeft, depth+1);
    TreeNode* right = TrainSingleTransformed(sortedRight, depth+1);
    return new Cut(bestCutDim, cutPos, left, right);
  }
443 
444  //......................................................................
445  double Tree::Classify(const Evt& e) const
446  {
447  return fHead->Classify(TransformEvent(e, fMatrix));
448  }
449 
  //......................................................................
  Cut::Cut(int d, double v, TreeNode* l, TreeNode* r)
    : fCutDim(d), fCutVal(v), fLeft(l), fRight(r)
  {
    // A cut only makes sense on a sample that contains at least some
    // signal and some background
    assert(fLeft->NSig()+fRight->NSig() > 0);
    assert(fLeft->NBkg()+fRight->NBkg() > 0);
  }
457 
458  //......................................................................
459  double Cut::FOM() const
460  {
461  // FOM is the sum in quadrature of our children
462  return sqrt(util::sqr(fLeft->FOM())+util::sqr(fRight->FOM()));
463  }
464 
465  //......................................................................
466  double Cut::Classify(const Evt& e) const
467  {
468  if(e.vars[fCutDim] < fCutVal){
469  return fLeft->Classify(e);
470  }
471  else{
472  return fRight->Classify(e);
473  }
474  }
475 
476  //......................................................................
477  double Cut::NSig() const
478  {
479  return fLeft->NSig()+fRight->NSig();
480  }
481 
482  //......................................................................
483  double Cut::NBkg() const
484  {
485  return fLeft->NBkg()+fRight->NBkg();
486  }
487 
  //......................................................................
  TreeNode* Cut::Prune(std::vector<Evt>& evts)
  {
    // Prune this subtree using the independent test sample `evts` (which
    // is consumed). May delete this node and return a replacement Leaf,
    // so callers must continue through the returned pointer.
    std::vector<Evt> leftEvts, rightEvts;
    const unsigned int N = evts.size();
    double sigL = 0, bkgL = 0, sigR = 0, bkgR = 0;
    // Split the test sample with this node's cut, tallying both sides
    for(unsigned int n = 0; n < N; ++n){
      const Evt& e = evts[n];
      if(e.vars[fCutDim] < fCutVal){
        leftEvts.push_back(e);
        if(e.isSig) sigL += e.weight; else bkgL += e.weight;
      }
      else{
        rightEvts.push_back(e);
        if(e.isSig) sigR += e.weight; else bkgR += e.weight;
      }
    }
    evts.clear(); // save memory while recursing

    // The children may replace themselves with leaves
    fLeft = fLeft->Prune(leftEvts);
    fRight = fRight->Prune(rightEvts);

    // If originally trained with higher s/b on the right but test sample has
    // it higher on the left, or vice versa, then overtrained, merge the
    // trees.
    const bool trainSign = fLeft->NSig()*fRight->NBkg() < fRight->NSig()*fLeft->NBkg();
    const bool testSign = sigL*bkgR < sigR*bkgL;

    if(fLeft->IsLeaf() && fRight->IsLeaf() && testSign != trainSign){
      TreeNode* ret = new Leaf(fLeft->NSig()+fRight->NSig(),
                               fLeft->NBkg()+fRight->NBkg());
      delete this;
      return ret;
    }

    return this;
  }
525 
  //......................................................................
  void Cut::ToTree(TTree* tr) const
  {
    // Recurse first so there'll be entries in gPtrs to look left and right up in
    fLeft->ToTree(tr);
    fRight->ToTree(tr);

    // This node occupies the next unused row; record that for our parent
    gPtrs[this] = tr->GetEntries();

    gThis = gPtrs[this];
    gLeft = gPtrs[fLeft];
    gRight = gPtrs[fRight];
    gCutDim = fCutDim;
    gCutVal = fCutVal;
    gSig = -1; // sentinel: sig/bkg counts are only stored on Leaf rows
    gBkg = -1;

    tr->Fill();
  }
545 
546 
  //......................................................................
  // Record the signal/background weight that reached this leaf in training
  Leaf::Leaf(double s, double b) : fSig(s), fBkg(b)
  {
  }
551 
552  //......................................................................
553  double Leaf::FOM() const
554  {
555  if(fSig == 0) return 0;
556  return fSig/sqrt(fSig+fBkg);
557  }
558 
559  //......................................................................
560  double Leaf::Classify(const Evt&) const
561  {
562  if(fSig == 0) return 0;
563  return fSig/(fSig+fBkg);
564  }
565 
  //......................................................................
  void Leaf::ToTree(TTree* tr) const
  {
    // This leaf occupies the next unused row; record that so parent Cuts
    // can refer back to it
    gPtrs[this] = tr->GetEntries();

    gThis = gPtrs[this];
    gLeft = -1;   // -1 children mark this row as a Leaf on readback
    gRight = -1;
    gCutDim = -1; // no cut on a leaf
    gCutVal = -1;
    gSig = fSig;
    gBkg = fBkg;

    tr->Fill();
  }
581 
582  } // namespace dec
583 } // namespace lem
584 
Leaf of a decision tree. No further cuts are made
Definition: DecisionTree.h:150
virtual void ToTree(TTree *tr) const =0
Serialize to a TTree. Needs assistance from Forest::ToTree.
Decision Tree PID.
constexpr auto const & right(const_AssnsIter< L, R, D, Dir > const &a, const_AssnsIter< L, R, D, Dir > const &b)
Definition: AssnsIter.h:112
void ToFile(const std::string &fname)
Write out PID structure to a file.
const int kNumPIDVars
Definition: DecisionTree.h:23
set< int >::iterator it
virtual double FOM() const =0
Estimated figure of merit of this subtree.
fvar< T > fabs(const fvar< T > &x)
Definition: fabs.hpp:15
const Var weight
static Forest * FromFile(const std::string &fname)
Load PID from a file.
virtual double Classify(const Evt &) const
Calculate PID value of e.
static TreeNode * TrainSingleTransformed(std::vector< std::list< Evt * > > &sorted, unsigned int depth)
Internal helper: train one (sub)tree.
T sqrt(T number)
Definition: d0nt_math.hpp:156
OStream cerr
Definition: OStream.cxx:7
virtual bool IsLeaf() const =0
Evt TransformEvent(const Evt &evt, const TMatrixD &mat)
double fSig
How many signal events reached this node in training.
Definition: DecisionTree.h:165
TreeNode * fLeft
Definition: DecisionTree.h:146
T sqr(T x)
More efficient square function than pow(x,2)
Definition: MathUtil.h:23
double fBkg
How many background events reached this node in training.
Definition: DecisionTree.h:166
#define M_PI
Definition: SbMath.h:34
std::vector< Tree * > fTrees
Definition: DecisionTree.h:79
Leaf(double s, double b)
A training or trial event for the decision tree.
Definition: DecisionTree.h:28
friend void * thread_func(void *)
PID
Definition: FillPIDs.h:14
double Classify(const Evt &evt) const
Calculate the PID value for evt.
const XML_Char * s
Definition: expat.h:262
void Cut(double x)
Definition: plot_outliers.C:1
Double_t scale
Definition: plot.C:25
virtual void ToTree(TTree *tr) const
Serialize to a TTree. Needs assistance from Forest::ToTree.
virtual double NBkg() const
Number of background events below this point in the tree.
virtual double FOM() const
Estimated figure of merit of this subtree.
virtual double FOM() const
Estimated figure of merit of this subtree.
int fCutDim
What variable the cut is on.
Definition: DecisionTree.h:144
static TMatrixD RandomOrthoMatrix()
Internal helper. Generate a random set of orthogonal unit vectors.
const double a
"Random forest" of decision trees
Definition: DecisionTree.h:42
int evt
const std::vector< Evt > * evts
Float_t d
Definition: plot.C:236
const int nThreads
Definition: PhotonSim_mp.C:69
const double j
Definition: BetheBloch.cxx:29
virtual double NSig() const
Number of signal events below this point in the tree.
Eigen::VectorXd vec
virtual double Classify(const Evt &e) const
Calculate PID value of e.
virtual double NSig() const =0
Number of signal events below this point in the tree.
OStream cout
Definition: OStream.cxx:6
const unsigned int kMinBucketSize
Once there are less than this many events in a subtree, stop subdividing.
Cut(int d, double v, TreeNode *l, TreeNode *r)
TreeNode * fRight
Subtrees to descend into based on cut result.
Definition: DecisionTree.h:146
static Forest Train(std::vector< Evt > &trainEvts, unsigned int nTrees, bool parallel=false)
Initial training of the PID.
constexpr auto const & left(const_AssnsIter< L, R, D, Dir > const &a, const_AssnsIter< L, R, D, Dir > const &b)
Definition: AssnsIter.h:104
T sin(T number)
Definition: d0nt_math.hpp:132
double Classify(const Evt &e) const
Calculate PID value of e.
Abstract base class for tree cuts and leaves.
Definition: DecisionTree.h:83
virtual double NBkg() const =0
Number of background events below this point in the tree.
double vars[kNumPIDVars]
Definition: DecisionTree.h:30
const hit & b
Definition: hits.cxx:21
T cos(T number)
Definition: d0nt_math.hpp:78
assert(nhit_max >=nhit_nbins)
TRandom3 r(0)
A DecisionTree. Forwards most things to the head TreeNode.
Definition: DecisionTree.h:105
virtual void ToTree(TTree *tr) const
Serialize to a TTree. Needs assistance from Forest::ToTree.
virtual double Classify(const Evt &e) const =0
Calculate PID value of e.
void Format(TGraph *gr, int lcol, int lsty, int lwid, int mcol, int msty, double msiz)
Definition: Style.cxx:154
static Tree * TrainSingle(const std::vector< Evt > &trainEvts, const TMatrixD &scaleMat, const TMatrixD &transMat)
virtual TreeNode * Prune(std::vector< Evt > &evts)
Float_t e
Definition: plot.C:35
float fSig[xbins]
Definition: MakePlots.C:86
Float_t w
Definition: plot.C:20
#define W(x)
void next()
Definition: show_event.C:84
const unsigned int kDepthLimit
Don't build a tree deeper than this.
double fCutVal
Where the cut is placed.
Definition: DecisionTree.h:145
Eigen::MatrixXd mat
virtual TreeNode * Prune(std::vector< Evt > &evts)=0
enum BeamMode string