2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation. //
7 // This program is distributed in the hope that it will be useful, but //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU //
10 // General Public License for more details. //
12 // You should have received a copy of the GNU General Public License //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>. //
15 // Written by Francois Fleuret //
16 // (C) Idiap Research Institute //
18 // Contact <francois.fleuret@idiap.ch> for comments & bug reports //
19 ///////////////////////////////////////////////////////////////////////////
21 #include "classifier_reader.h"
22 #include "fusion_sort.h"
24 #include "boosted_classifier.h"
// Builds an untrained classifier that will hold nb_weak_learners weak
// learners; the count is only recorded here, no learner is created yet.
// NOTE(review): the remainder of this constructor is not visible in this
// excerpt. Presumably _weak_learners is also initialised here -- confirm,
// since the destructor deletes through that pointer.
27 BoostedClassifier::BoostedClassifier(int nb_weak_learners) {
28 _nb_weak_learners = nb_weak_learners;
// Builds an empty classifier with a weak-learner count of zero.
// NOTE(review): the remainder of this constructor is not visible in this
// excerpt. Presumably _weak_learners is also initialised here -- confirm.
32 BoostedClassifier::BoostedClassifier() {
33 _nb_weak_learners = 0;
// Releases every weak learner individually, then the array of pointers
// that held them.
// NOTE(review): one line between the signature and the loop is not visible
// in this excerpt (possibly a guard on _weak_learners). If there is no such
// guard, an object constructed with a non-zero count but never trained or
// read would index through an unallocated array here -- confirm.
37 BoostedClassifier::~BoostedClassifier() {
39 for(int w = 0; w < _nb_weak_learners; w++)
40 delete _weak_learners[w];
41 delete[] _weak_learners;
// Adds up, into r, the responses of all the weak learners for sample number
// n_sample of sample_set.
// NOTE(review): the declaration of r and the return statement are not
// visible in this excerpt; presumably r starts at zero and is the value
// returned -- confirm.
45 scalar_t BoostedClassifier::response(SampleSet *sample_set, int n_sample) {
47 for(int w = 0; w < _nb_weak_learners; w++) {
48 r += _weak_learners[w]->response(sample_set, n_sample);
// Trains the _nb_weak_learners decision trees one after the other.
//
//   loss_machine    forwarded to each tree's train() and used to compute the
//                   train loss written to the log.
//   sample_set      the training samples; its labels are also tallied for
//                   the log.
//   train_responses array indexed by sample number (0 .. nb_samples()-1),
//                   modified in place: after each tree is trained, that
//                   tree's response on every sample is added to it.
//
// Progress is reported on global.log_stream.
54 void BoostedClassifier::train(LossMachine *loss_machine,
55 SampleSet *sample_set, scalar_t *train_responses) {
// NOTE(review): the condition guarding this error message, and whatever
// follows it, are not visible in this excerpt; presumably it fires when
// _weak_learners has already been allocated -- confirm.
58 cerr << "Can not re-train a BoostedClassifier" << endl;
// Tally the samples by label sign, for the log only. Samples whose label is
// exactly zero are counted in neither total.
62 int nb_pos = 0, nb_neg = 0;
64 for(int s = 0; s < sample_set->nb_samples(); s++) {
65 if(sample_set->label(s) > 0) nb_pos++;
66 else if(sample_set->label(s) < 0) nb_neg++;
69 _weak_learners = new DecisionTree *[_nb_weak_learners];
71 (*global.log_stream) << "With " << nb_pos << " positive and " << nb_neg << " negative samples." << endl;
73 for(int w = 0; w < _nb_weak_learners; w++) {
75 _weak_learners[w] = new DecisionTree();
76 _weak_learners[w]->train(loss_machine, sample_set, train_responses);
// Fold the new tree's output into the running responses, so that the next
// tree's train() call receives the accumulated values.
78 for(int n = 0; n < sample_set->nb_samples(); n++)
79 train_responses[n] += _weak_learners[w]->response(sample_set, n);
// The loss logged here is computed on the responses updated just above.
81 (*global.log_stream) << "Weak learner " << w
82 << " depth " << _weak_learners[w]->depth()
83 << " nb_leaves " << _weak_learners[w]->nb_leaves()
84 << " train loss " << loss_machine->loss(sample_set, train_responses)
89 (*global.log_stream) << "Built a classifier with " << _nb_weak_learners << " weak_learners." << endl;
// Forwards the array used to the tag_used_features() method of every weak
// learner, in order. The array is passed through untouched by this method
// itself; what each tree writes into it is decided by DecisionTree.
92 void BoostedClassifier::tag_used_features(bool *used) {
93 for(int w = 0; w < _nb_weak_learners; w++)
94 _weak_learners[w]->tag_used_features(used);
// Forwards the array new_indexes to the re_index_features() method of every
// weak learner, in order.
// NOTE(review): presumably new_indexes maps old feature indexes to new
// ones; the exact convention is defined by DecisionTree -- confirm there.
97 void BoostedClassifier::re_index_features(int *new_indexes) {
98 for(int w = 0; w < _nb_weak_learners; w++)
99 _weak_learners[w]->re_index_features(new_indexes);
// Loads the classifier from the stream is: first the number of weak
// learners, then each decision tree in turn, every tree reading itself from
// the same stream. The depth and number of leaves of each tree are logged
// on global.log_stream.
102 void BoostedClassifier::read(istream *is) {
// NOTE(review): the condition guarding this error message, and whatever
// follows it, are not visible in this excerpt; presumably it fires when
// _weak_learners has already been allocated -- confirm.
104 cerr << "Can not read over an existing BoostedClassifier" << endl;
// NOTE(review): no validation of the count read from the stream is visible
// here before it is used as an array size; a corrupt or negative value
// would reach the allocation unchecked -- consider validating.
108 read_var(is, &_nb_weak_learners);
109 _weak_learners = new DecisionTree *[_nb_weak_learners];
110 for(int w = 0; w < _nb_weak_learners; w++) {
111 _weak_learners[w] = new DecisionTree();
112 _weak_learners[w]->read(is);
113 (*global.log_stream) << "Read tree " << w << " of depth "
114 << _weak_learners[w]->depth() << " with "
115 << _weak_learners[w]->nb_leaves() << " leaves." << endl;
// NOTE(review): the beginning of this statement (the stream being written
// to) is not visible in this excerpt.
119 << "Read BoostedClassifier containing " << _nb_weak_learners << " weak learners." << endl;
// Saves the classifier to the stream os: the number of weak learners,
// followed by each decision tree writing itself to the same stream. The
// count-then-trees order is the one read() consumes.
122 void BoostedClassifier::write(ostream *os) {
// NOTE(review): the declaration of id, and any statement writing it to os,
// are not visible in this excerpt. Presumably the CLASSIFIER_BOOSTED tag is
// written ahead of the count so that a reader can pick the class to
// instantiate; read() in this file does not consume such a tag itself --
// confirm against the classifier reader.
124 id = CLASSIFIER_BOOSTED;
127 write_var(os, &_nb_weak_learners);
128 for(int w = 0; w < _nb_weak_learners; w++)
129 _weak_learners[w]->write(os);