Public Member Functions | Static Public Member Functions | Protected Member Functions | Static Protected Member Functions | Protected Attributes | Friends | List of all members
lem::dec::Forest Class Reference

"Random forest" of decision trees More...

#include "/cvmfs/nova-development.opensciencegrid.org/novasoft/releases/N21-04-15/LEM/func/DecisionTree.h"

Public Member Functions

 ~Forest ()
 
double Classify (const Evt &evt) const
 Calculate the PID value for evt. More...
 
void ToFile (const std::string &fname)
 Write out PID structure to a file. More...
 

Static Public Member Functions

static Forest * FromFile (const std::string &fname)
 Load PID from a file. More...
 
static Forest Train (std::vector< Evt > &trainEvts, unsigned int nTrees, bool parallel=false)
 Initial training of the PID. More...
 

Protected Member Functions

 Forest ()
 

Static Protected Member Functions

static Tree * TrainSingle (const std::vector< Evt > &trainEvts, const TMatrixD &scaleMat, const TMatrixD &transMat)
 
static TreeNode * TrainSingleTransformed (std::vector< std::list< Evt * > > &sorted, unsigned int depth)
 Internal helper: train one (sub)tree. More...
 
static TMatrixD RandomOrthoMatrix ()
 Internal helper. Generate a random set of orthogonal unit vectors. More...
 

Protected Attributes

std::vector< Tree * > fTrees
 

Friends

void * thread_func (void *)
 

Detailed Description

"Random forest" of decision trees

All user interaction should be with this class

Definition at line 42 of file DecisionTree.h.

Constructor & Destructor Documentation

lem::dec::Forest::~Forest ( )

Definition at line 43 of file DecisionTree.cxx.

References fTrees, and confusionMatrixTree::t.

44  {
45  for(Tree* t: fTrees) delete t;
46  }
std::vector< Tree * > fTrees
Definition: DecisionTree.h:79
lem::dec::Forest::Forest ( )
inline protected

Definition at line 65 of file DecisionTree.h.

References lem::dec::thread_func().

Referenced by FromFile().

65 {}

Member Function Documentation

double lem::dec::Forest::Classify ( const Evt & evt) const

Calculate the PID value for evt.

Definition at line 49 of file DecisionTree.cxx.

References fTrees, extractScale::mean, and confusionMatrixTree::t.

Referenced by lem::MakePID::produce(), test_dectree(), and train_dectree_caf().

50  {
51  // Average all the constituent trees
52  double mean = 0;
53  for(Tree* t: fTrees) mean += t->Classify(evt);
54  mean /= fTrees.size();
55  return mean;
56  }
std::vector< Tree * > fTrees
Definition: DecisionTree.h:79
int evt
Forest * lem::dec::Forest::FromFile ( const std::string & fname)
static

Load PID from a file.

Definition at line 59 of file DecisionTree.cxx.

References ana::assert(), om::cout, Cut(), allTimeWatchdog::endl, MakeMiniprodValidationCuts::f, Forest(), genie::utils::style::Format(), fTrees, it, lem::dec::kNumPIDVars, art::left(), mat, runNovaSAM::ret, art::right(), and make_root_from_grid_output::tr.

Referenced by lem::MakePID::MakePID(), and test_dectree().

60  {
61  TFile f(fname.c_str());
62  assert(!f.IsZombie());
63  TTree* tr = (TTree*)f.Get("tree");
64  assert(tr);
65 
66  tr->SetBranchAddress("this", &gThis);
67  tr->SetBranchAddress("left", &gLeft);
68  tr->SetBranchAddress("right", &gRight);
69  tr->SetBranchAddress("cutdim", &gCutDim);
70  tr->SetBranchAddress("cutval", &gCutVal);
71  tr->SetBranchAddress("sig", &gSig);
72  tr->SetBranchAddress("bkg", &gBkg);
73 
74  // Map from TTree rows to pointers, use to connect up nodes
75  std::map<int, TreeNode*> ptrs;
76 
77  // For every row in the tree
78  const int N = tr->GetEntries();
79  for(int n = 0; n < N; ++n){
80  tr->GetEntry(n);
81 
82  if(gLeft == -1){
83  assert(gRight == -1);
84 
85  // If it has no children it's a simple Leaf
86  ptrs[gThis] = new Leaf(gSig, gBkg);
87  }
88  else{
89  assert(gRight != -1);
90 
91  // If it does have children, it's a Cut. We should have seen its
92  // children already, so look them up in the map.
93 
94  TreeNode* left = ptrs[gLeft];
95  TreeNode* right = ptrs[gRight];
96  assert(left && right);
97 
98  // Only retain unparented nodes in the map
99  ptrs.erase(ptrs.find(gLeft));
100  ptrs.erase(ptrs.find(gRight));
101 
102  ptrs[gThis] = new Cut(gCutDim, gCutVal, left, right);
103  }
104  }
105 
106  // Anything that never found a parent is a tree root
107  std::cout << "Loaded " << ptrs.size() << " trees" << std::endl;
108  Forest* ret = new Forest;
109 
110  int matIdx = 0;
111  // Fallback in case file is old
112  TMatrixD unit(kNumPIDVars+1, kNumPIDVars+1);
113  unit.UnitMatrix();
114 
115  for(auto it: ptrs){
116  TMatrixD* mat = (TMatrixD*)f.Get(TString::Format("mat_%d", matIdx++));
117  ret->fTrees.push_back(new Tree(it.second, mat ? *mat : unit));
118  }
119 
120  return ret;
121  }
constexpr auto const & right(const_AssnsIter< L, R, D, Dir > const &a, const_AssnsIter< L, R, D, Dir > const &b)
Definition: AssnsIter.h:104
const int kNumPIDVars
Definition: DecisionTree.h:23
set< int >::iterator it
void Cut(double x)
Definition: plot_outliers.C:1
std::void_t< T > n
Float_t mat
Definition: plot.C:39
OStream cout
Definition: OStream.cxx:6
constexpr auto const & left(const_AssnsIter< L, R, D, Dir > const &a, const_AssnsIter< L, R, D, Dir > const &b)
Definition: AssnsIter.h:96
assert(nhit_max >=nhit_nbins)
void Format(TGraph *gr, int lcol, int lsty, int lwid, int mcol, int msty, double msiz)
Definition: Style.cxx:154
TMatrixD lem::dec::Forest::RandomOrthoMatrix ( )
static protected

Internal helper. Generate a random set of orthogonal unit vectors.

Definition at line 150 of file DecisionTree.cxx.

References std::cos(), MECModelEnuComparisons::i, calib::j, lem::dec::kNumPIDVars, M_PI, r(), runNovaSAM::ret, and std::sin().

Referenced by TrainSingle().

151  {
152  // Start with the identity matrix
153  TMatrixD ret(kNumPIDVars+1, kNumPIDVars+1);
154  ret.UnitMatrix();
155 
156  TRandom3 r(0); // auto-seed
157 
158  // Random rotation by all the Euler angles
159  for(int i = 0; i < kNumPIDVars-1; ++i){
160  for(int j = i+1; j < kNumPIDVars; ++j){
161  TMatrixD mul(kNumPIDVars+1, kNumPIDVars+1);
162  mul.UnitMatrix();
163 
164  double ang = r.Uniform(2*M_PI);
165  mul(i, i) = mul(j, j) = cos(ang);
166  mul(i, j) = +sin(ang);
167  mul(j, i) = -sin(ang);
168 
169  ret = mul*ret;
170  }
171  }
172 
173  return ret;
174  }
const int kNumPIDVars
Definition: DecisionTree.h:23
#define M_PI
Definition: SbMath.h:34
const double j
Definition: BetheBloch.cxx:29
T sin(T number)
Definition: d0nt_math.hpp:132
T cos(T number)
Definition: d0nt_math.hpp:78
TRandom3 r(0)
void lem::dec::Forest::ToFile ( const std::string & fname)

Write out PID structure to a file.

Definition at line 124 of file DecisionTree.cxx.

References genie::utils::style::Format(), submit_syst::fout, fTrees, MECModelEnuComparisons::i, and make_root_from_grid_output::tr.

Referenced by train_dectree_caf().

125  {
126  TFile fout(fname.c_str(), "RECREATE");
127  TTree tr("tree", "tree");
128 
129  // The ToTree functions of the individual nodes require us to have done this setup
130  tr.Branch("this", &gThis);
131  tr.Branch("left", &gLeft);
132  tr.Branch("right", &gRight);
133  tr.Branch("cutdim", &gCutDim);
134  tr.Branch("cutval", &gCutVal);
135  tr.Branch("sig", &gSig);
136  tr.Branch("bkg", &gBkg);
137 
138  // And they maintain the state of gPtrs
139  gPtrs.clear();
140 
141  for(unsigned int i = 0; i < fTrees.size(); ++i){
142  fTrees[i]->ToTree(&tr);
143  fTrees[i]->GetMatrix().Write(TString::Format("mat_%u", i));
144  }
145 
146  tr.Write("tree");
147  }
std::vector< Tree * > fTrees
Definition: DecisionTree.h:79
void Format(TGraph *gr, int lcol, int lsty, int lwid, int mcol, int msty, double msiz)
Definition: Style.cxx:154
Forest lem::dec::Forest::Train ( std::vector< Evt > &  trainEvts,
unsigned int  nTrees,
bool  parallel = false 
)
static

Initial training of the PID.

Parameters
trainEvts: The events to be used for training
nTrees: How many trees to train. The final PID is an average of all of them
parallel: Use multiple cores?

Definition at line 203 of file DecisionTree.cxx.

References make_syst_table_plots::args, om::cerr, om::cout, e, allTimeWatchdog::endl, fTrees, MECModelEnuComparisons::i, lem::dec::kNumPIDVars, extractScale::mean, nThreads, runNovaSAM::ret, extractScale::rms, util::sqr(), std::sqrt(), thread_func, make_root_from_grid_output::tr, TrainSingle(), lem::dec::Evt::vars, W, lem::dec::Evt::weight, and ana::weight.

Referenced by train_dectree_caf().

205  {
206  Forest ret;
207 
208  const unsigned int N = trainEvts.size();
209 
210  // Because we only take half the sample each time below
211  for(unsigned int n = 0; n < N; ++n) trainEvts[n].weight *= 2;
212 
213  // Accumulate the mean and RMS for each variable
214  double mean[kNumPIDVars] = {0,};
215  double rms[kNumPIDVars] = {0,};
216  double W = 0;
217  for(unsigned int n = 0; n < N; ++n){
218  const Evt& e = trainEvts[n];
219  for(int i = 0; i < kNumPIDVars; ++i){
220  if(fabs(e.vars[i]) < 10){ // There are some crazy values in there
221  mean[i] += e.weight*e.vars[i];
222  rms[i] += e.weight*util::sqr(e.vars[i]);
223  }
224  }
225  W += e.weight;
226  }
227 
228  // Translation matrix. To learn how this works, read
229  // http://en.wikipedia.org/wiki/Translation_matrix
230  TMatrixD transMat(kNumPIDVars+1, kNumPIDVars+1);
231  transMat.UnitMatrix();
232 
233  // Matrix to normalize coordinates into sigmas
234  TMatrixD scaleMat(kNumPIDVars+1, kNumPIDVars+1);
235  scaleMat(kNumPIDVars, kNumPIDVars) = 1;
236 
237  for(int i = 0; i < kNumPIDVars; ++i){
238  mean[i] /= W;
239  rms[i] = rms[i]/W-util::sqr(mean[i]);
240 
241  transMat(i, kNumPIDVars) = -mean[i];
242  scaleMat(i, i) = 1/sqrt(rms[i]);
243  }
244 
245 
246  if(parallel){
247  // Somewhat oversubscribe the number of cores, to keep them all busy
248  const int nThreads = sysconf(_SC_NPROCESSORS_ONLN)*1.5;
249 
250  std::vector<pthread_t*> ths;
251  int live = 0; // counter how many threads are currently running
252 
253  ThreadArgs args = {&trainEvts, scaleMat, transMat};
254 
255  while(true){
256  // Start new threads if needed
257  while(live < nThreads && ret.fTrees.size()+live < nTrees){
258  ths.push_back(new pthread_t);
259  pthread_create(ths.back(), 0, &thread_func, &args);
260  ++live;
261  }
262 
263  // Break when we're done
264  if(live == 0 && ret.fTrees.size() >= nTrees) break;
265 
266  // Have to busy-wait because sleep() suspends all the threads
267 
268  // Check all the threads to see if they've finished
269  for(unsigned int i = 0; i < ths.size(); ++i){
270  void* tr;
271 #ifndef DARWINBUILD
272  if(ths[i] && pthread_tryjoin_np(*ths[i], &tr) == 0){
273 #else
274  std::cerr << "Building on OSX requires using pthread_join rather than pthread_join_np" << std::endl;
275  if(ths[i] && pthread_join(*ths[i], &tr) == 0){
276 #endif
277  // If so, clean them up and store their results
278  delete ths[i];
279  ths[i] = 0;
280  ret.fTrees.push_back((Tree*)tr);
281  --live;
282  std::cout << "." << std::flush;
283  }
284  } // end for i
285  } // end while
286  std::cout << std::endl;
287  }
288  else{
289  for(unsigned int n = 0; n < nTrees; ++n){
290  ret.fTrees.push_back(TrainSingle(trainEvts, scaleMat, transMat));
291  std::cout << "." << std::flush;
292  }
293  std::cout << std::endl;
294  }
295 
296  // Try to put things back as we found them
297  for(unsigned int n = 0; n < N; ++n) trainEvts[n].weight /= 2;
298 
299  return ret;
300  }
const int kNumPIDVars
Definition: DecisionTree.h:23
const Var weight
T sqrt(T number)
Definition: d0nt_math.hpp:156
OStream cerr
Definition: OStream.cxx:7
T sqr(T x)
More efficient square function than pow(x,2)
Definition: MathUtil.h:23
friend void * thread_func(void *)
std::void_t< T > n
const int nThreads
Definition: PhotonSim_mp.C:69
OStream cout
Definition: OStream.cxx:6
static Tree * TrainSingle(const std::vector< Evt > &trainEvts, const TMatrixD &scaleMat, const TMatrixD &transMat)
Float_t e
Definition: plot.C:35
#define W(x)
Tree * lem::dec::Forest::TrainSingle ( const std::vector< Evt > &  trainEvts,
const TMatrixD & scaleMat,
const TMatrixD & transMat 
)
static protected

Definition at line 303 of file DecisionTree.cxx.

References a, b, MECModelEnuComparisons::i, lem::dec::kNumPIDVars, mat, lem::dec::TreeNode::Prune(), RandomOrthoMatrix(), getGoodRuns4SAM::test, TrainSingleTransformed(), lem::dec::TransformEvent(), compareCafs::tree, registry_explorer::v, and lem::dec::Evt::vars.

Referenced by lem::dec::thread_func(), and Train().

306  {
307  // To avoid linear artifacts in the output, we rotate into a random
308  // basis. Note that this is different for each tree in the forest.
309  const TMatrixD rotMat = RandomOrthoMatrix();
310 
311  // Before rotating, center and normalize the variables so they're all
312  // treated on an equal footing.
313  const TMatrixD mat = rotMat * scaleMat * transMat;
314 
315  const unsigned int N = trainEvts.size();
316 
317  std::vector<Evt> transformedEvts;
318  for(unsigned int i = 0; i < N; ++i){
319  transformedEvts.push_back(TransformEvent(trainEvts[i], mat));
320  }
321 
322  // Shuffle the training events and take the first half. This ensures
323  // each tree gets a different sample, so on average they don't
324  // overtrain.
325  std::random_device rng;
326  std::mt19937 urng(rng());
327  std::shuffle(transformedEvts.begin(), transformedEvts.end(), urng);
328 
329  // The training algorithm requires multiple lists of the events, sorted
330  // according to each PID variable.
331  std::vector<std::list<Evt*>> sorted(kNumPIDVars);
332  for(int v = 0; v < kNumPIDVars; ++v){
333  for(unsigned int i = 0; i < N/2; ++i){
334  sorted[v].push_back(&transformedEvts[i]);
335  }
336 
337  sorted[v].sort([v](const Evt* a, const Evt* b)
338  {
339  return a->vars[v] < b->vars[v];
340  });
341  }
342 
343  // Train the tree and add it to the forest
344  TreeNode* tree = TrainSingleTransformed(sorted, 0);
345  std::vector<Evt> test(transformedEvts.begin()+N/2, transformedEvts.end());
346  tree = tree->Prune(test);
347  return new Tree(tree, mat);
348  }
const int kNumPIDVars
Definition: DecisionTree.h:23
static TreeNode * TrainSingleTransformed(std::vector< std::list< Evt * > > &sorted, unsigned int depth)
Internal helper: train one (sub)tree.
Evt TransformEvent(const Evt &evt, const TMatrixD &mat)
static TMatrixD RandomOrthoMatrix()
Internal helper. Generate a random set of orthogonal unit vectors.
const double a
Float_t mat
Definition: plot.C:39
const hit & b
Definition: hits.cxx:21
TreeNode * lem::dec::Forest::TrainSingleTransformed ( std::vector< std::list< Evt * > > &  sorted,
unsigned int  depth 
)
static protected

Internal helper: train one (sub)tree.

Definition at line 357 of file DecisionTree.cxx.

References Cut(), febshutoff_auto::end, evt, it, lem::dec::kNumPIDVars, art::left(), next(), art::right(), util::sqr(), registry_explorer::v, and w.

Referenced by TrainSingle().

359  {
360  const unsigned int N = sorted[0].size();
361 
362  double totSig = 0;
363  double totBkg = 0;
364 
365  for(Evt* evt : sorted[0]){
366  const double w = evt->weight;
367  if(evt->isSig) totSig += w; else totBkg += w;
368  }
369 
370  // Variety of reasons not to try to cut anymore
371  if(depth == kDepthLimit || N < kMinBucketSize || totSig == 0 || totBkg == 0)
372  return new Leaf(totSig, totBkg);
373 
374  double bestFOMSq = 0;
375  // Convention that the invalid iterator is taken from the first list
376  std::list<Evt*>::iterator bestCutIt = sorted[0].end();
377  int bestCutDim = 0;
378 
379  // Try a cut in every dimension
380  for(int dim = 0; dim < kNumPIDVars; ++dim){
381  // Counts below the cut
382  double s1 = 0;
383  double b1 = 0;
384  for(std::list<Evt*>::iterator it = sorted[dim].begin(); it != sorted[dim].end(); ++it){
385  // Keep the counts below and above up to date
386  const double w = (*it)->weight;
387  if((*it)->isSig) s1 += w; else b1 += w;
388 
389  const double s2 = totSig-s1;
390  const double b2 = totBkg-b1;
391 
392  if(s1+b1 == 0 || s2+b2 == 0) continue;
393 
394  // Figure of merit if we were to cut here
395  const double fomSq = util::sqr(s1)/(s1+b1)+util::sqr(s2)/(s2+b2);
396 
397  if(fomSq > bestFOMSq){
398  bestFOMSq = fomSq;
399  if(it != sorted[dim].begin()){
400  // Definition of the cut is the last iterator in the "left" sample
401  bestCutIt = it;
402  --bestCutIt;
403  }
404  bestCutDim = dim;
405  }
406  } // end for n
407  } // end for dim
408 
409  // Didn't find a cut, or it's no better
410  if(bestCutIt == sorted[0].end() ||
411  bestFOMSq <= util::sqr(totSig)/(totSig+totBkg))
412  return new Leaf(totSig, totBkg);
413 
414  // First iterator in the "right" sample
415  std::list<Evt*>::iterator next = bestCutIt;
416  ++next;
417 
418  // Cut is halfway between the two points we want to seperate
419  const double cutPos = ((*bestCutIt)->vars[bestCutDim]+(*next)->vars[bestCutDim])/2;
420 
421  // Children require the same sorted lists that we got
422  std::vector<std::list<Evt*> > sortedLeft(kNumPIDVars);
423  std::vector<std::list<Evt*> > sortedRight(kNumPIDVars);
424 
425  for(int v = 0; v < kNumPIDVars; ++v){
426  for(Evt* evt : sorted[v]){
427  if(evt->vars[bestCutDim] < cutPos)
428  sortedLeft[v].push_back(evt);
429  else
430  sortedRight[v].push_back(evt);
431  }
432 
433  // Make sure to clear lists out once we don't need them, save memory
434  // while recursing.
435  sorted[v].clear();
436  }
437  sorted.clear();
438 
439  TreeNode* left = TrainSingleTransformed(sortedLeft, depth+1);
440  TreeNode* right = TrainSingleTransformed(sortedRight, depth+1);
441  return new Cut(bestCutDim, cutPos, left, right);
442  }
constexpr auto const & right(const_AssnsIter< L, R, D, Dir > const &a, const_AssnsIter< L, R, D, Dir > const &b)
Definition: AssnsIter.h:104
const int kNumPIDVars
Definition: DecisionTree.h:23
set< int >::iterator it
static TreeNode * TrainSingleTransformed(std::vector< std::list< Evt * > > &sorted, unsigned int depth)
Internal helper: train one (sub)tree.
T sqr(T x)
More efficient square function than pow(x,2)
Definition: MathUtil.h:23
void Cut(double x)
Definition: plot_outliers.C:1
int evt
const unsigned int kMinBucketSize
Once there are less than this many events in a subtree, stop subdividing.
constexpr auto const & left(const_AssnsIter< L, R, D, Dir > const &a, const_AssnsIter< L, R, D, Dir > const &b)
Definition: AssnsIter.h:96
Float_t w
Definition: plot.C:20
void next()
Definition: show_event.C:84
const unsigned int kDepthLimit
Don't build a tree deeper than this.

Friends And Related Function Documentation

void* thread_func ( void * a)
friend

Definition at line 196 of file DecisionTree.cxx.

Referenced by Train().

197  {
198  ThreadArgs* args = (ThreadArgs*)a;
199  return Forest::TrainSingle(*args->evts, args->scale, args->trans);
200  }
const double a
static Tree * TrainSingle(const std::vector< Evt > &trainEvts, const TMatrixD &scaleMat, const TMatrixD &transMat)

Member Data Documentation

std::vector<Tree*> lem::dec::Forest::fTrees
protected

Definition at line 79 of file DecisionTree.h.

Referenced by Classify(), FromFile(), ToFile(), Train(), and ~Forest().


The documentation for this class was generated from the following files: