DoTraining.C
Go to the documentation of this file.
#include <cstdlib>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "TChain.h"
#include "TCut.h"
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVA/MethodCategory.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"
17 
18 void DoTraining(){
19 
20  // ------------------------------------------------
21  // stuff that requires user input is up front
22  // ------------------------------------------------
23 
24  // get signal and background trees from input file
25  TFile* inFile = new TFile("./TrainingTrees.root");
26  TTree* ncTree = (TTree*)inFile->Get("ncTree");
27  TTree* cosmicTree = (TTree*)inFile->Get("cosTree");
28 
29  // define weights for nc and cosmic events, i.e.
30  // weight these to the same POT/livetime
31  double ncWeight = 0.000136814;
32  double cosmicWeight = 0.010794024;
33 
34  TCut signalCuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
35  TCut backgroundCuts = ""; // for example: TCut mycutb = "abs(var1)<0.5";
36 
37  // create output file
38  TFile* outFile = new TFile("outFile.root", "recreate");
39 
40  // setup factory options
41  // -- V : verbose
42  // -- Silent : suppress all output
43  // -- Transformations : list of transformations to check:
44  // identity,
45  // decorreation
46  // PCA,
47  // gaussianisation followed by decorrelation,
48  // -- AnalysisType : classification or regression
49  //
50  // for additional information on Transformations see Section 4.1
51  // of the TMVA Users' Guide: https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf
52 
53  std::string factoryRunOpt = "V:!Silent";
54  // there are issues with the decorrelated varibles so
55  // neglect those options for now
56  //std::string transformations = "I;D;P;G,D" ;
57  std::string transformations = "I;P";
58  std::string analysisType = "Classification";
59 
60  std::string factoryOptions = factoryRunOpt +
61  ":Transformations=" + transformations +
62  ":AnalysisType=" + analysisType;
63 
64  // setup sample preparation options
65  // -- V : verbose
66  // -- nTrainSignal : number of signal events for training
67  // -- nTrainBackground : number of background events for training
68  // -- splitMode : defines how the training and test samples
69  // are selected from the source trees
70  // -- normMode : renormalisation of events
71 
72  std::string prepRunOpt = "V";
73  std::string nTrainSignal = "0";
74  std::string nTrainBackground = "0";
75  std::string splitMode = "Random";
76  // std::string normMode = "NumEvents";
77  std::string normMode = "None";
78 
79 
80  std::string prepOptions = prepRunOpt +
81  ":nTrain_Signal=" + nTrainSignal +
82  ":nTrain_Background=" + nTrainBackground +
83  ":SplitMode=" + splitMode +
84  ":NormMode=" + normMode;
85 
86  // setup sample preparation options
87  // -- V : verbose
88  // -- H : print help information
89  // -- InverseBoostNegWeights : treatment of events with negative
90  // weights
91  // -- nTrees : number of trees in the training forest
92  // -- boostType : boosting type for the trees in the forest
93  // (options are AdaBoost, RealAdaBoost,
94  // Bagging, AdBoostR2, Grad)
95  // -- shrinkage : learning rate for GradBoost algorithm
96  // -- nCuts : number of grid points in variable range
97  // used in finding optimal cut in node splitting
98  // -- maxDepth : max depth of decision trees allowed
99  // -- minNodeSize : minimum percentage of training events
100  // required in a leaf node
101  // -- node purity limit : in boosting/pruning, nodes with purity
102  // > than this value are signal, otherwise
103  // background
104 
105  std::string bookRunOpt = "V:!H:InverseBoostNegWeights";
106  std::string nTrees = "1000";
107  std::string shrinkage = "0.1";
108  std::string nCuts = "30";
109  std::string maxDepth = "5";
110  std::string minNodeSize = "2%";
111  std::string nodePurityLimit = "0.9";
112  std::vector<std::string> boostTypes =
113  {"Grad",
114  "RealAdaBoost"};
115  //"AdaBoost",
116  //"Bagging"};
117 
118  // -----------------------------------------------
119  // don't change stuff below here without knowing
120  // what you're doing
121  // -----------------------------------------------
122 
123  // load TMVA libraries
124  TMVA::Tools::Instance();
125 
126  TMVA::Factory *factory = new TMVA::Factory("TMVAClassifier",
127  outFile,
128  factoryOptions );
129 
130  // setup TMVA::DataLoader, and add training/spectator variables
131  TMVA::DataLoader* dataLoader = new TMVA::DataLoader("dataLoader");
132 
133  dataLoader->AddVariable("vtxx" , 'F');
134  dataLoader->AddVariable("vtxy" , 'F');
135  dataLoader->AddVariable("vtxz" , 'F');
136  dataLoader->AddVariable("closestslicemindist" , 'F');
137  dataLoader->AddVariable("ncontplanes" , 'I');
138  dataLoader->AddVariable("shwhitx" , 'I');
139  dataLoader->AddVariable("shwhity" , 'I');
140  dataLoader->AddVariable("shwhittot" , 'I');
141  dataLoader->AddVariable("shwhitasymm" , 'I');
142  dataLoader->AddVariable("shwhitratio" , 'F');
143  dataLoader->AddVariable("nshowers" , 'I');
144  dataLoader->AddVariable("showerwidth" , 'F');
145  dataLoader->AddVariable("showerlength" , 'F');
146  dataLoader->AddVariable("showerdirycosine" , 'F');
147  dataLoader->AddVariable("showergap" , 'F');
148  dataLoader->AddVariable("showercale" , 'F');
149  dataLoader->AddVariable("nhitsperplane" , 'F');
150  dataLoader->AddVariable("nmiphits" , 'I');
151  dataLoader->AddVariable("nhitsperslice" , 'F');
152  dataLoader->AddVariable("partptp" , 'F');
153  dataLoader->AddVariable("closestslicetime" , 'F');
154 
155  dataLoader->AddSpectator("run" , 'I');
156  dataLoader->AddSpectator("subrun" , 'I');
157  dataLoader->AddSpectator("evt" , 'I');
158  dataLoader->AddSpectator("subevt" , 'S');
159 
160  // add signal and background tree
161  dataLoader->AddSignalTree (ncTree , ncWeight );
162  dataLoader->AddBackgroundTree(cosmicTree, cosmicWeight);
163 
164  // prepare trees
165  dataLoader->PrepareTrainingAndTestTree(signalCuts, backgroundCuts, prepOptions);
166 
167  for (size_t i = 0; i < boostTypes.size(); ++i){
168 
169  std::string bookOptions = bookRunOpt +
170  ":NTrees=" + nTrees +
171  ":BoostType=" + boostTypes.at(i) +
172  ":Shrinkage=" + shrinkage +
173  ":nCuts=" + nCuts +
174  ":MaxDepth=" + maxDepth +
175  ":MinNodeSize=" + minNodeSize +
176  ":NodePurityLimit= " + nodePurityLimit;
177 
178  factory->BookMethod(dataLoader, TMVA::Types::kBDT, "BDTA"+boostTypes.at(i), bookOptions);
179  }
180 
181  factory->TrainAllMethods();
182  factory->TestAllMethods();
183  factory->EvaluateAllMethods();
184 
185  outFile->Close();
186 
187  delete factory;
188 
189  if (!gROOT->IsBatch()) TMVA::TMVAGui( outFile->GetName() );
190 
191 }
ifstream inFile
Definition: AnaPlotMaker.h:34
TFile * outFile
Definition: PlotXSec.C:135
void DoTraining()
Definition: DoTraining.C:18
enum BeamMode string