tmva_new_train.C
Go to the documentation of this file.
1 //
2 // Created by josh (Joshua Porter; whoever retrains ReMId next, feel free to contact me at jccporter@gmail.com)
3 // on 11/03/2019.
4 // This is the current version used for ReMId retraining for the 2020 analysis. The old version, `tmva_remid_train_classifiers`,
5 // is located in /scripts. It works with an older version of TMVA and contains snippets of code you might want to steal if
6 // you're considering using a different algorithm (not BDTG or KNN).
7 //
8 
9 #include <cstdlib>
10 #include <iostream>
11 #include <map>
12 #include <string>
13 
14 #include "TChain.h"
15 #include "TFile.h"
16 #include "TTree.h"
17 #include "TString.h"
18 #include "TObjString.h"
19 #include "TSystem.h"
20 #include "TROOT.h"
21 
22 #include "TMVA/Factory.h"
23 #include "TMVA/DataLoader.h"
24 #include "TMVA/Tools.h"
25 #include "TMVA/TMVAGui.h"
26 
27 
28 int tmva_new_train(TString myMethodList = "") {
29 
30  TMVA::Tools::Instance();
31 
32  std::map<std::string, int> Use;
33 
34 
35  Use["KNN"] = 0; // k-nearest neighbour method
36 
37  Use["BDTG"] = 1; // uses Gradient Boost
38 
39  //
40  // Friedman's RuleFit method, ie, an optimised series of cuts ("rules")
41  Use["RuleFit"] = 0;
42  // ---------------------------------------------------------------
43 
45  std::cout << "==> Start TMVAClassification" << std::endl;
46 
47  // Select methods (don't look at this code - not of interest)
48  if (myMethodList != "") {
49  for (std::map<std::string, int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
50 
51  std::vector <TString> mlist = TMVA::gTools().SplitString(myMethodList, ',');
52  for (UInt_t i = 0; i < mlist.size(); i++) {
53  std::string regMethod(mlist[i]);
54 
55  if (Use.find(regMethod) == Use.end()) {
56  std::cout << "Method \"" << regMethod
57  << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
58  for (std::map<std::string, int>::iterator it = Use.begin(); it != Use.end(); it++)
59  std::cout << it->first << " ";
60  std::cout << std::endl;
61  return 1;
62  }
63  Use[regMethod] = 1;
64  }
65  }
66 
67 
68  chdir("results");
69 
70  std::string trainingFileName = "tmva_train_trees.root";
71  std::string outputFileName = "TrainingOutput.root";
72 
73  TFile *siginput = TFile::Open(trainingFileName.c_str(),
74  "CACHEREAD");
75  TFile *outputFile = TFile::Open(outputFileName.c_str(),
76  "RECREATE");
77 
78  TTree *signalTree = (TTree *) siginput->Get("SigTree");
79  TTree *backgroundTree = (TTree *) siginput->Get("BackTree");
80 
81  TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
82  "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
83 
84  TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
85 
86  dataloader->AddVariable("trackLength", "trackLen", "cm", 'F');
87  dataloader->AddVariable("dedxSep", "de_dxLLh", "GeV/cm", 'F');
88  dataloader->AddVariable("scatSep", "scatLLh", "degrees", 'F');
89  dataloader->AddVariable("measFrac", "measFracE", "", 'F');
90 
91 
92  Double_t signalWeight = 1.0;
93  Double_t backgroundWeight = 1.0;
94 
95  dataloader->AddSignalTree(signalTree, signalWeight);
96  dataloader->AddBackgroundTree(backgroundTree, backgroundWeight);
97 
98  int numSignalEvents = signalTree->Draw("","1");
99  int numBackgroundEvents = backgroundTree->Draw("","1");
100 
101 
102 
103 
104  // Apply additional cuts on the signal and background samples (can be different)
105  TCut mycuts = "";
106  TCut mycutb = "";
107 
108 
109  signalTree->SetEventList(0);
110  backgroundTree->SetEventList(0);
111 
112  std::string trainTestOptions;
113  float ratio = 0.9;
114  trainTestOptions += "nTrain_Signal=" + std::to_string(int(numSignalEvents*ratio));
115  trainTestOptions += ":nTrain_Background=" + std::to_string(int(numBackgroundEvents*ratio));
116  trainTestOptions += ":nTest_Signal=" + std::to_string(int(numSignalEvents*(1 - ratio)));
117  trainTestOptions += ":nTest_Background=" + std::to_string(int(numBackgroundEvents*(1 - ratio)));
118  //Split the data set randomly
119  trainTestOptions += ":SplitMode=Random";
120  //SplitSeed=0 random number generator starts with random seed
121  trainTestOptions += ":SplitSeed=100";
122 
123  dataloader->PrepareTrainingAndTestTree(mycuts, mycutb,
124  trainTestOptions);
125 
126  // Cut optimisation
127 
128  // K-Nearest Neighbour classifier (KNN)
129  if (Use["KNN"])
130  factory->BookMethod(dataloader, TMVA::Types::kKNN, "KNN",
131  "nkNN=80:ScaleFrac=0.8:SigmaFact=1.0:Kernel=Gaus:UseKernel=F:UseWeight=T:!Trim");
132 
133 
134  // Boosted Decision Trees
135  if (Use["BDTG"]) // Gradient Boost
136  factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTG",
137  "!H:!V:NTrees=1000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=800:MaxDepth=24");
138 
139 
140  //
141  // Train MVAs using the set of training events
142  factory->TrainAllMethods();
143 
144  // Evaluate all MVAs using the set of test events
145  factory->TestAllMethods();
146 
147  // Evaluate and compare performance of all configured MVAs
148  factory->EvaluateAllMethods();
149 
150  // --------------------------------------------------------------
151 
152  // Save the output
153  outputFile->Close();
154 
155  std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
156  std::cout << "==> TMVAClassification is done!" << std::endl;
157 
158  delete factory;
159  delete dataloader;
160  // Launch the GUI for the root macros
161  if (!gROOT->IsBatch()) TMVA::TMVAGui(outputFileName.c_str());
162 
163  return 0;
164 }
165 
166 int main(int argc, char **argv) {
167  // Select methods (don't look at this code - not of interest)
168  TString methodList;
169  for (int i = 1; i < argc; i++) {
170  TString regMethod(argv[i]);
171  if (regMethod == "-b" || regMethod == "--batch") continue;
172  if (!methodList.IsNull()) methodList += TString(",");
173  methodList += regMethod;
174  }
175  return tmva_new_train(methodList);
176 }
set< int >::iterator it
int main(int argc, char **argv)
TH1 * ratio(TH1 *h1, TH1 *h2)
int tmva_new_train(TString myMethodList="")
OStream cout
Definition: OStream.cxx:6
std::string to_string(ModuleType mt)
Definition: ModuleType.h:32
enum BeamMode string