mubarid_training.C
//................................................
// This script was made to train/test the mubar PID by
// providing 4 input variables:
// 1- dedxLL
// 2- ScatdedxLL
// 3- Average dedx in the last 10 cm
// 4- Average dedx in the last 40 cm
//
// Author: Biswaranjan Behera
// Email:  bbehera@fnal.gov
//...............................................
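// Usage sketch (the output file name below is only illustrative):
//   root -l -b -q 'mubarid_training.C+("muonid_bdt_grad.root", "grad", 500, 2.5, 4)'
// where `method` must be one of "grad", "ada" or "bagged".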
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

#include "TChain.h"
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVA/Tools.h"
#include "TMVA/Factory.h"
#include "TMVA/TMVAGui.h"
#include "TMVA/Reader.h"
#include "TMVA/DataLoader.h"
#include "TMVA/MethodBDT.h"

using namespace TMVA;

void mubarid_training(std::string foutput, std::string method = "", int ntrees = 500, float min_node_size = 2.5, int maxdepth = 4)
{
  std::cout << "MuonID BDT Training" << "\n"
            << "......................................................." << std::endl;

  // Basic I/O
  // Open the input files; the input variables are added to the DataLoader below.

  // Both the train and test trees are made from 20% of the dataset
  TFile *fFHCTrainFile = TFile::Open("/nova/ana/users/connorj/AntiNumuCC/MuonIDProd5/muonid_fhc_train_trim.root");
  TFile *fRHCTrainFile = TFile::Open("/nova/ana/users/connorj/AntiNumuCC/MuonIDProd5/muonid_rhc_train_trim.root");
  TFile *fFHCTestFile  = TFile::Open("/nova/ana/users/connorj/AntiNumuCC/MuonIDProd5/muonid_fhc_test_trim.root");
  TFile *fRHCTestFile  = TFile::Open("/nova/ana/users/connorj/AntiNumuCC/MuonIDProd5/muonid_rhc_test_trim.root");

  // Check that the input files opened correctly
  assert(fFHCTrainFile && !fFHCTrainFile->IsZombie());
  assert(fRHCTrainFile && !fRHCTrainFile->IsZombie());
  assert(fFHCTestFile  && !fFHCTestFile->IsZombie());
  assert(fRHCTestFile  && !fRHCTestFile->IsZombie());

  // Defining the output file
  TFile* foutfile = TFile::Open(foutput.c_str(), "RECREATE");

  // Declaring the factory and variables
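  // Factory options: "V" = verbose, "!Color" = plain (uncolored) log output,
  // "DrawProgressBar" = show a progress bar during training,
  // "Transformations=I" = only the identity transformation of the inputs,
  // "AnalysisType=Classification" = binary signal/background classification.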
  TMVA::Factory * factory = new TMVA::Factory("MuonID", foutfile, "V:!Color:DrawProgressBar:Transformations=I:AnalysisType=Classification");

  TMVA::DataLoader * dataloader = new TMVA::DataLoader("muonid_training_set");

  dataloader->AddVariable("DedxLL",          "dE/dx LL",        "", 'F');
  dataloader->AddVariable("ScatLL",          "Scattering LL",   "", 'F');
  dataloader->AddVariable("Avededxlast10cm", "Avededxlast10cm", "", 'F');
  dataloader->AddVariable("Avededxlast40cm", "Avededxlast40cm", "", 'F');
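  // AddVariable arguments: (branch/expression, display title, unit, type), with 'F' = float.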

  std::cout << " Hurrah!!... All input variables are added" << "\n"
            << ".............................................." << std::endl;

  // Register the training and test trees
  TTree *sigFHCTrainTree = (TTree*)fFHCTrainFile->Get("sigTree");
  TTree *bkgFHCTrainTree = (TTree*)fFHCTrainFile->Get("bkgTree");
  TTree *sigFHCTestTree  = (TTree*)fFHCTestFile ->Get("sigTree");
  TTree *bkgFHCTestTree  = (TTree*)fFHCTestFile ->Get("bkgTree");
  TTree *sigRHCTrainTree = (TTree*)fRHCTrainFile->Get("sigTree");
  TTree *bkgRHCTrainTree = (TTree*)fRHCTrainFile->Get("bkgTree");
  TTree *sigRHCTestTree  = (TTree*)fRHCTestFile ->Get("sigTree");
  TTree *bkgRHCTestTree  = (TTree*)fRHCTestFile ->Get("bkgTree");

  // fFHCTrainFile->Close();
  // fRHCTrainFile->Close();
  // fFHCTestFile ->Close();
  // fRHCTestFile ->Close();
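  // NOTE: the input files are left open on purpose; closing them here would delete the
  // TTrees retrieved above before the DataLoader has read them.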

  // Specify which trees to use for training / testing (modify as you wish)
  dataloader->AddSignalTree    (sigFHCTrainTree, 1.0, TMVA::Types::kTraining);
  dataloader->AddSignalTree    (sigRHCTrainTree, 1.0, TMVA::Types::kTraining);
  dataloader->AddSignalTree    (sigFHCTestTree,  1.0, TMVA::Types::kTesting);
  dataloader->AddSignalTree    (sigRHCTestTree,  1.0, TMVA::Types::kTesting);
  dataloader->AddBackgroundTree(bkgFHCTrainTree, 1.0, TMVA::Types::kTraining);
  dataloader->AddBackgroundTree(bkgRHCTrainTree, 1.0, TMVA::Types::kTraining);
  dataloader->AddBackgroundTree(bkgFHCTestTree,  1.0, TMVA::Types::kTesting);
  dataloader->AddBackgroundTree(bkgRHCTestTree,  1.0, TMVA::Types::kTesting);
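  // The second argument is a global per-tree event weight (1.0 = unweighted); the third pins
  // each tree to either the training or the test sample.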

  std::cout << "Added signal and background trees." << std::endl;

  // Booking TMVA methods
  // Boosted Decision Trees (Gradient, AdaBoost, or Bagged boosting, chosen via `method`)
  // ....................................

  // ORIGINAL LOGIC, unoptimized
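  // Option-string fields: "!H" suppresses the help text, "V" is verbose; Shrinkage is the
  // gradient-boost learning rate; UseBaggedBoost with BaggedSampleFraction=0.5 grows each tree
  // on a random 50% subsample; AdaBoostBeta is the AdaBoost learning rate; SeparationType sets
  // the node-splitting criterion; nCuts is the number of cut points scanned per variable.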
  TMVA::MethodBDT * bdtMethod;
  if (method == "grad"){
    bdtMethod = (TMVA::MethodBDT*) factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTG_GradBoost", // Gradient Boost
                  "!H:V:BoostType=Grad:Shrinkage=0.50:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20");
  }
  else if (method == "ada"){
    bdtMethod = (TMVA::MethodBDT*) factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT_AdaBoost", // AdaBoost
                  "!H:V:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20");
  }
  else if (method == "bagged"){
    bdtMethod = (TMVA::MethodBDT*) factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTB_BaggedBoost", // Bagged Boost
                  "!H:!V:BoostType=Bagging:SeparationType=GiniIndex:nCuts=20");
  }
  else{
    throw std::runtime_error("Unknown BDT Method: " + method);
  }
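  // Override the defaults booked above with the user-supplied hyperparameters
  // (MinNodeSize is given in percent of the training sample).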
  bdtMethod->SetNTrees(ntrees);
  bdtMethod->SetMinNodeSize(min_node_size);
  bdtMethod->SetMaxDepth(maxdepth);
  std::cout << "User Defined Training Params:\n\tntrees = " << ntrees << "\n\tminnode=" << min_node_size << "\n\tdepth=" << maxdepth << std::endl;

  // Prepare the dataloader
  std::cout << "Preparing..." << std::endl;
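  // Arguments: (preselection cut, nTrain_Signal, nTrain_Background, nTest_Signal, nTest_Background, options).
  // The empty cut applies no preselection; the -1 counts are used here to take all available events,
  // and the train/test assignment is already fixed by the kTraining/kTesting flags above; "V" = verbose.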
  dataloader->PrepareTrainingAndTestTree("", -1, -1, -1, -1, "V");

  // Train MVAs using the set of training events
  std::cout << "Training All Methods." << std::endl;
  factory->TrainAllMethods();

  // Evaluate all MVAs using the set of test events
  std::cout << "Testing All Methods." << std::endl;
  factory->TestAllMethods();

  // Evaluate and compare performance of all configured MVAs
  std::cout << "Evaluating All Methods!" << std::endl;
  factory->EvaluateAllMethods();

  // Save the output
  foutfile->Close();

  std::cout << "Output file is " << foutfile->GetName() << std::endl;
  std::cout << "Mubar PID selection using TMVA is done successfully........ " << "\n"
            << "Next Step: Make plots using TMVA::Gui!" << "\n"
            << "-----------------------------xxx------------------------------" << std::endl;

  // Launch the GUI
  if (!gROOT->IsBatch()) TMVA::TMVAGui( foutfile->GetName() );

  delete factory;
  delete dataloader;

  // ...................................................

} // end of mubarid_training