automatic commit
[folded-ctf.git] / boosted_classifier.cc
1 /*
2  *  folded-ctf is an implementation of the folded hierarchy of
3  *  classifiers for object detection, developed by Francois Fleuret
4  *  and Donald Geman.
5  *
6  *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of folded-ctf.
10  *
11  *  folded-ctf is free software: you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License as published
13  *  by the Free Software Foundation, either version 3 of the License,
14  *  or (at your option) any later version.
15  *
16  *  folded-ctf is distributed in the hope that it will be useful, but
17  *  WITHOUT ANY WARRANTY; without even the implied warranty of
18  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  *  General Public License for more details.
20  *
21  *  You should have received a copy of the GNU General Public License
22  *  along with folded-ctf.  If not, see <http://www.gnu.org/licenses/>.
23  *
24  */
25
26 #include "classifier_reader.h"
27 #include "fusion_sort.h"
28
29 #include "boosted_classifier.h"
30 #include "tools.h"
31
32 BoostedClassifier::BoostedClassifier(int nb_weak_learners) {
33   _nb_weak_learners = nb_weak_learners;
34   _weak_learners = 0;
35 }
36
37 BoostedClassifier::BoostedClassifier() {
38   _nb_weak_learners = 0;
39   _weak_learners = 0;
40 }
41
42 BoostedClassifier::~BoostedClassifier() {
43   if(_weak_learners) {
44     for(int w = 0; w < _nb_weak_learners; w++)
45       delete _weak_learners[w];
46     delete[] _weak_learners;
47   }
48 }
49
50 scalar_t BoostedClassifier::response(SampleSet *sample_set, int n_sample) {
51   scalar_t r = 0;
52   for(int w = 0; w < _nb_weak_learners; w++) {
53     r += _weak_learners[w]->response(sample_set, n_sample);
54     ASSERT(!isnan(r));
55   }
56   return r;
57 }
58
59 void BoostedClassifier::train(LossMachine *loss_machine,
60                               SampleSet *sample_set, scalar_t *train_responses) {
61
62   if(_weak_learners) {
63     cerr << "Can not re-train a BoostedClassifier" << endl;
64     exit(1);
65   }
66
67   int nb_pos = 0, nb_neg = 0;
68
69   for(int s = 0; s < sample_set->nb_samples(); s++) {
70     if(sample_set->label(s) > 0) nb_pos++;
71     else if(sample_set->label(s) < 0) nb_neg++;
72   }
73
74   _weak_learners = new DecisionTree *[_nb_weak_learners];
75
76   (*global.log_stream) << "With " << nb_pos << " positive and " << nb_neg << " negative samples." << endl;
77
78   for(int w = 0; w  < _nb_weak_learners; w++) {
79
80     _weak_learners[w] = new DecisionTree();
81     _weak_learners[w]->train(loss_machine, sample_set, train_responses);
82
83     for(int n = 0; n < sample_set->nb_samples(); n++)
84       train_responses[n] += _weak_learners[w]->response(sample_set, n);
85
86     (*global.log_stream) << "Weak learner " << w
87          << " depth " << _weak_learners[w]->depth()
88          << " nb_leaves " << _weak_learners[w]->nb_leaves()
89          << " train loss " << loss_machine->loss(sample_set, train_responses)
90          << endl;
91
92   }
93
94   (*global.log_stream) << "Built a classifier with " << _nb_weak_learners << " weak_learners." << endl;
95 }
96
97 void BoostedClassifier::tag_used_features(bool *used) {
98   for(int w = 0; w < _nb_weak_learners; w++)
99     _weak_learners[w]->tag_used_features(used);
100 }
101
102 void BoostedClassifier::re_index_features(int *new_indexes) {
103   for(int w = 0; w < _nb_weak_learners; w++)
104     _weak_learners[w]->re_index_features(new_indexes);
105 }
106
107 void BoostedClassifier::read(istream *is) {
108   if(_weak_learners) {
109     cerr << "Can not read over an existing BoostedClassifier" << endl;
110     exit(1);
111   }
112
113   read_var(is, &_nb_weak_learners);
114   _weak_learners = new DecisionTree *[_nb_weak_learners];
115   for(int w = 0; w < _nb_weak_learners; w++) {
116     _weak_learners[w] = new DecisionTree();
117     _weak_learners[w]->read(is);
118     (*global.log_stream) << "Read tree " << w << " of depth "
119                          << _weak_learners[w]->depth() << " with "
120                          << _weak_learners[w]->nb_leaves() << " leaves." << endl;
121   }
122
123   (*global.log_stream)
124     << "Read BoostedClassifier containing " << _nb_weak_learners << " weak learners." << endl;
125 }
126
127 void BoostedClassifier::write(ostream *os) {
128   unsigned int id;
129   id = CLASSIFIER_BOOSTED;
130   write_var(os, &id);
131
132   write_var(os, &_nb_weak_learners);
133   for(int w = 0; w < _nb_weak_learners; w++)
134     _weak_learners[w]->write(os);
135 }