DecisionTree.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 /// \file DecisionTree.h
3 /// \brief Decision Tree PID
4 /// \author Christopher Backhouse - bckhouse@caltech.edu
5 ////////////////////////////////////////////////////////////////////////
6 
7 #ifndef LEM_DECISIONTREE_H
8 #define LEM_DECISIONTREE_H
9 
10 #include <list>
11 #include <string>
12 #include <vector>
13 
14 #include "TMatrixD.h"
15 
16 class TTree;
17 
18 namespace lem
19 {
20  /// Decision tree %PID
21  namespace dec
22  {
/// Number of input variables stored in each Evt::vars array
const int kNumPIDVars = 6;

/// \brief One event, either used to train the decision tree or to be classified by it
///
/// Classification only reads \a vars; the remaining fields need not be
/// filled in for that use.
struct Evt
{
  /// Input variables handed to the %PID
  double vars[kNumPIDVars];
  /// Presumably a CC/NC interaction flag -- TODO confirm against the training code
  int ccnc;
  /// Presumably a PDG particle code -- TODO confirm
  int pdg;
  /// Signal (true) / background (false) label -- presumably only read when training
  bool isSig;
  /// Per-event weight -- presumably only read when training; confirm
  double weight;
};
35 
36  class Tree;
37  class TreeNode;
38 
39  /// \brief "Random forest" of decision trees
40  ///
41  /// All user interaction should be with this class
42  class Forest
43  {
44  public:
45  ~Forest();
46 
47  /// Calculate the %PID value for \a evt
48  double Classify(const Evt& evt) const;
49 
50  /// Load %PID from a file
51  static Forest* FromFile(const std::string& fname);
52 
53  /// \brief Initial training of the %PID
54  ///
55  /// \param trainEvts The events to be used for training
56  /// \param nTrees How many trees to train. The final %PID is an average of all
57  /// \param parallel Use multiple cores?
58  static Forest Train(std::vector<Evt>& trainEvts,
59  unsigned int nTrees,
60  bool parallel = false);
61 
62  /// Write out %PID structure to a file
63  void ToFile(const std::string& fname);
64  protected:
65  Forest(){}
66 
67  friend void* thread_func(void*);
68  static Tree* TrainSingle(const std::vector<Evt>& trainEvts,
69  const TMatrixD& scaleMat,
70  const TMatrixD& transMat);
71 
72  /// Internal helper: train one (sub)tree
73  static TreeNode* TrainSingleTransformed(std::vector<std::list<Evt*> >& sorted,
74  unsigned int depth);
75 
76  /// Internal helper. Generate a random set of orthogonal unit vectors
77  static TMatrixD RandomOrthoMatrix();
78 
79  std::vector<Tree*> fTrees;
80  };
81 
82  /// Abstract base class for tree cuts and leaves
83  class TreeNode
84  {
85  public:
86  virtual ~TreeNode() {}
87  /// Estimated figure of merit of this subtree
88  virtual double FOM() const = 0;
89  /// Calculate %PID value of \a e
90  virtual double Classify(const Evt& e) const = 0;
91  /// Number of signal events below this point in the tree
92  virtual double NSig() const = 0;
93  /// Number of background events below this point in the tree
94  virtual double NBkg() const = 0;
95 
96  virtual TreeNode* Prune(std::vector<Evt>& evts) = 0;
97 
98  virtual bool IsLeaf() const = 0;
99 
100  /// Serialize to a TTree. Needs assistance from \ref Forest::ToTree
101  virtual void ToTree(TTree* tr) const = 0;
102  };
103 
104  /// A DecisionTree. Forwards most things to the head \ref TreeNode
105  class Tree
106  {
107  public:
109  : fHead(n), fMatrix(mat) {}
110  ~Tree() {delete fHead;}
111 
112  /// Estimated figure of merit of this tree
113  double FOM() const {return fHead->FOM();}
114  /// Calculate %PID value of \a e
115  double Classify(const Evt& e) const;
116 
117  void Prune(std::vector<Evt>& evts) {fHead = fHead->Prune(evts);};
118 
119  /// Serialize to a TTree. Needs assistance from \ref Forest::ToTree
120  void ToTree(TTree* tr) const {fHead->ToTree(tr);}
121 
122  TMatrixD GetMatrix() const {return fMatrix;}
123  protected:
126  };
127 
128  /// A cut dividing events into two samples, maximizing FOM
129  class Cut: public TreeNode
130  {
131  public:
132  Cut(int d, double v, TreeNode* l, TreeNode* r);
133  virtual double FOM() const;
134  virtual double Classify(const Evt& e) const;
135  virtual double NSig() const;
136  virtual double NBkg() const;
137 
138  virtual TreeNode* Prune(std::vector<Evt>& evts);
139 
140  bool IsLeaf() const {return false;}
141 
142  virtual void ToTree(TTree* tr) const;
143  protected:
144  int fCutDim; ///< What variable the cut is on
145  double fCutVal; ///< Where the cut is placed
146  TreeNode *fLeft, *fRight; ///< Subtrees to descend into based on cut result
147  };
148 
149  /// %Leaf of a decision tree. No further cuts are made
150  class Leaf: public TreeNode
151  {
152  public:
153  Leaf(double s, double b);
154  virtual double FOM() const;
155  virtual double Classify(const Evt&) const;
156  virtual double NSig() const {return fSig;}
157  virtual double NBkg() const {return fBkg;}
158 
159  virtual TreeNode* Prune(std::vector<Evt>&){return this;}
160 
161  bool IsLeaf() const {return true;}
162 
163  virtual void ToTree(TTree* tr) const;
164  protected:
165  double fSig; ///< How many signal events reached this node in training
166  double fBkg; ///< How many background events reached this node in training
167  };
168  } // namespace dec
169 } // namespace lem
170 
171 #endif
Leaf of a decision tree. No further cuts are made
Definition: DecisionTree.h:150
Tree(TreeNode *n, const TMatrixD &mat)
Definition: DecisionTree.h:108
const int kNumPIDVars
Definition: DecisionTree.h:23
double FOM() const
Estimated figure of merit of this tree.
Definition: DecisionTree.h:113
A cut dividing events into two samples, maximizing FOM.
Definition: DecisionTree.h:129
void ToTree(TTree *tr) const
Serialize to a TTree. Needs assistance from Forest::ToTree.
Definition: DecisionTree.h:120
bool IsLeaf() const
Definition: DecisionTree.h:140
double fSig
How many signal events reached this node in training.
Definition: DecisionTree.h:165
double fBkg
How many background events reached this node in training.
Definition: DecisionTree.h:166
std::vector< Tree * > fTrees
Definition: DecisionTree.h:79
A training or trial event for the decision tree.
Definition: DecisionTree.h:28
PID
Definition: FillPIDs.h:14
TMatrixD GetMatrix() const
Definition: DecisionTree.h:122
const XML_Char * s
Definition: expat.h:262
void Cut(double x)
Definition: plot_outliers.C:1
virtual double NBkg() const
Number of background events below this point in the tree.
Definition: DecisionTree.h:157
void * thread_func(void *a)
int fCutDim
What variable the cut is on.
Definition: DecisionTree.h:144
std::void_t< T > n
TreeNode * fHead
Definition: DecisionTree.h:124
"Random forest" of decision trees
Definition: DecisionTree.h:42
int evt
Float_t d
Definition: plot.C:236
Float_t mat
Definition: plot.C:39
TreeNode * fRight
Subtrees to descend into based on cut result.
Definition: DecisionTree.h:146
bool IsLeaf() const
Definition: DecisionTree.h:161
TMatrixD fMatrix
Definition: DecisionTree.h:125
virtual ~TreeNode()
Definition: DecisionTree.h:86
virtual TreeNode * Prune(std::vector< Evt > &)
Definition: DecisionTree.h:159
Abstract base class for tree cuts and leaves.
Definition: DecisionTree.h:83
void Prune(std::vector< Evt > &evts)
Definition: DecisionTree.h:117
double vars[kNumPIDVars]
Definition: DecisionTree.h:30
const hit & b
Definition: hits.cxx:21
virtual double NSig() const
Number of signal events below this point in the tree.
Definition: DecisionTree.h:156
TRandom3 r(0)
A DecisionTree. Forwards most things to the head TreeNode.
Definition: DecisionTree.h:105
Float_t e
Definition: plot.C:35
float fSig[xbins]
Definition: MakePlots.C:86
double fCutVal
Where the cut is placed.
Definition: DecisionTree.h:145
enum BeamMode string