/*
 * svrt is the ``Synthetic Visual Reasoning Test'', an image
 * generator for evaluating classification performance of machine
 * learning systems, humans and primates.
 *
 * Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
 * Written by Francois Fleuret <francois.fleuret@idiap.ch>
 *
 * This file is part of svrt.
 *
 * svrt is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * svrt is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with svrt. If not, see <http://www.gnu.org/licenses/>.
 */
25 #include "boosted_classifier.h"
26 #include "classifier_reader.h"
27 #include "fusion_sort.h"
30 inline scalar_t loss_derivative(int label, scalar_t response) {
31 return - scalar_t(label * 2 - 1) * exp( - scalar_t(label * 2 - 1) * response );
// Default constructor. Its body is not visible in this listing; presumably it
// leaves the classifier with no stumps (e.g. before read()) — TODO confirm.
34 BoostedClassifier::BoostedClassifier() {
39 BoostedClassifier::BoostedClassifier(int nb_weak_learners) {
40 _nb_stumps = nb_weak_learners;
41 _stumps = new Stump[_nb_stumps];
// Destructor. Its body is not visible in this listing; presumably it releases
// the `_stumps` array allocated with new[] by the constructor / read() —
// TODO confirm.
44 BoostedClassifier::~BoostedClassifier() {
// Returns this classifier's name as a C string (the body is not visible in
// this listing).
48 const char *BoostedClassifier::name() {
// Selects the decision stump stored in _stumps[t].
//
// For each of global.nb_optimization_weak_learners candidate stumps `tmp`
// (declared and generated on lines not visible in this listing), the
// candidate's count is evaluated on every sample's integral image. The
// samples are then sorted by that count, and every threshold lying strictly
// between two consecutive distinct counts is scored with a running sum `s` of
// derivatives. The candidate/threshold pair with the largest |s| is kept.
//
// t               index of the stump to fill in _stumps
// integral_images one integral image per sample (nb_samples entries)
// derivatives     per-sample loss derivatives (nb_samples entries)
// nb_samples      number of samples
//
// NOTE(review): if no threshold ever beats |max_loss_derivative| = 0 (e.g.
// every sample yields the same count), nothing is written to _stumps[t]
// here — confirm callers tolerate that.
52 void BoostedClassifier::chose_stump_from_sampling(int t, int **integral_images, scalar_t *derivatives, int nb_samples) {
53 int *indexes = new int[nb_samples];
54 int *sorted_indexes = new int[nb_samples];
55 scalar_t *stump_counts = new scalar_t[nb_samples];
57 scalar_t max_loss_derivative = 0;
// Try several candidate stumps and keep the best (candidate, threshold).
63 for(int k = 0; k < global.nb_optimization_weak_learners; k++) {
67 for(int n = 0; n < nb_samples; n++) {
68 stump_counts[n] = tmp.count(integral_images[n]);
// Order the samples by stump count; the strict `>` test in the sweep below
// relies on this order being ascending.
73 indexed_fusion_sort(nb_samples, indexes, sorted_indexes, stump_counts);
// Sweep the sorted samples. `s` is initialized on a line not visible here;
// presumably it starts as the sum of all derivatives, so subtracting
// 2 * derivatives[i] moves sample i from one side of the threshold to the
// other.
75 for(int n = 0; n < nb_samples - 1; n++) {
76 int i = sorted_indexes[n];
77 int j = sorted_indexes[n + 1];
78 s -= 2 * derivatives[i];
// Only a gap between two distinct counts gives a usable threshold.
79 if(stump_counts[j] > stump_counts[i]) {
// NOTE(review): `abs` on scalar_t values — make sure a floating-point
// overload (std::abs / fabs) is selected; the C `int abs(int)` would
// truncate.
80 if(abs(s) > abs(max_loss_derivative)) {
81 max_loss_derivative = s;
// NOTE(review): in the visible lines only `.threshold` is written to
// _stumps[t]; the other parameters of the winning `tmp` must be copied on a
// line not shown here, or _stumps[t] keeps stale parameters — confirm.
// The threshold is the midpoint between the two neighbouring counts.
83 _stumps[t].threshold = (stump_counts[i] + stump_counts[j])/2;
89 delete[] stump_counts;
// NOTE(review): no `delete[] indexes;` appears in the visible lines —
// verify it is released, otherwise it leaks on every call.
91 delete[] sorted_indexes;
94 void BoostedClassifier::chose_stump(int t, int **integral_images, scalar_t *derivatives, int nb_samples) {
95 if(global.nb_sampled_samples <= 0) {
96 chose_stump_from_sampling(t, integral_images, derivatives, nb_samples);
98 int *sampled_indexes = new int[global.nb_sampled_samples];
99 scalar_t *weights = new scalar_t[nb_samples];
100 for(int s = 0; s < nb_samples; s++) {
101 weights[s] = abs(derivatives[s]);
103 robust_sampling(nb_samples, weights, global.nb_sampled_samples, sampled_indexes);
106 int **sampled_integral_images = new int *[global.nb_sampled_samples];
107 scalar_t *sampled_derivatives = new scalar_t[global.nb_sampled_samples];
109 for(int s = 0; s < global.nb_sampled_samples; s++) {
110 sampled_integral_images[s] = integral_images[sampled_indexes[s]];
111 if(derivatives[sampled_indexes[s]] > 0) {
112 sampled_derivatives[s] = 1.0;
114 sampled_derivatives[s] = -1.0;
118 chose_stump_from_sampling(t, sampled_integral_images, sampled_derivatives, global.nb_sampled_samples);
120 delete[] sampled_derivatives;
121 delete[] sampled_integral_images;
122 delete[] sampled_indexes;
// Trains the boosted classifier on `nb_vignettes` labelled vignettes.
//
// An integral image is first computed for each vignette. Then, for each of
// the _nb_stumps rounds t:
//   - the exponential-loss derivatives are computed from the current
//     responses;
//   - a stump is chosen with chose_stump();
//   - its two output weights (weight0 / weight1) are set in closed form to
//     0.5 * log(num / den) and clamped to [-weight_max, weight_max];
//   - the stump's response is added to every sample's response.
// The two weights correspond to the samples on which the freshly chosen
// stump responds <= 0 / > 0. The final loss is printed to cout.
//
// nb_vignettes  number of training vignettes
// vignettes     the training vignettes (nb_vignettes entries)
// labels        one label per vignette; the loss uses (2 * label - 1), so
//               labels are expected to be 0 or 1
126 void BoostedClassifier::train(int nb_vignettes, Vignette *vignettes, int *labels) {
// One integral image per vignette, computed once up front. Each has
// (width + 1) * (height + 1) entries.
127 int **integral_images = new int *[nb_vignettes];
129 for(int n = 0; n < nb_vignettes; n++) {
130 integral_images[n] = new int[(Vignette::width + 1) * (Vignette::height + 1)];
131 compute_integral_image(&vignettes[n], integral_images[n]);
// `responses` holds each sample's current ensemble score. Its initialization
// (loop below) is on lines not visible here; presumably zero — TODO confirm.
134 scalar_t *responses = new scalar_t[nb_vignettes];
135 scalar_t *derivatives = new scalar_t[nb_vignettes];
137 for(int n = 0; n < nb_vignettes; n++) {
141 global.bar.init(&cout, _nb_stumps);
142 for(int t = 0; t < _nb_stumps; t++) {
// Loss derivative of every sample at its current response.
144 for(int n = 0; n < nb_vignettes; n++) {
145 derivatives[n] = loss_derivative(labels[n], responses[n]);
148 chose_stump(t, integral_images, derivatives, nb_vignettes);
// Accumulate, separately for each side of the stump, the sums used by the
// closed-form weight 0.5 * log(num / den). The lines choosing num vs den are
// not visible here; presumably num collects exp(-response) of label-1 samples
// and den collects exp(response) of label-0 samples — confirm.
150 scalar_t num0 = 0, den0 = 0, num1 = 0, den1 = 0;
152 for(int n = 0; n < nb_vignettes; n++) {
153 if(_stumps[t].response(integral_images[n]) > 0) {
155 num1 += exp( - responses[n] );
157 den1 += exp( responses[n] );
161 num0 += exp( - responses[n] );
163 den0 += exp( responses[n] );
// Clamp the weights so that a side with an empty num or den does not produce
// an infinite weight.
// NOTE(review): if a side ever receives no samples at all, num / den is
// 0 / 0 = NaN. NaN fails both clamp comparisons, so it would be stored as the
// weight and propagate into every response.
168 scalar_t weight_max = 5.0;
170 _stumps[t].weight0 = 0.5 * log(num0 / den0);
172 if(_stumps[t].weight0 < -weight_max)
173 _stumps[t].weight0 = -weight_max;
174 else if(_stumps[t].weight0 > weight_max)
175 _stumps[t].weight0 = weight_max;
177 _stumps[t].weight1 = 0.5 * log(num1 / den1);
178 if(_stumps[t].weight1 < -weight_max)
179 _stumps[t].weight1 = -weight_max;
180 else if(_stumps[t].weight1 > weight_max)
181 _stumps[t].weight1 = weight_max;
// Add the new stump's contribution to every sample's score.
183 for(int n = 0; n < nb_vignettes; n++) {
184 responses[n] += _stumps[t].response(integral_images[n]);
187 // cout << "ADABOOST_STEP " << t + 1 << " " << loss << endl;
188 global.bar.refresh(&cout, t);
190 global.bar.finish(&cout);
// Final exponential loss, sum over samples of exp(-(2 * label - 1) * response).
// `loss` is declared on a line not visible here.
193 for(int n = 0; n < nb_vignettes; n++) {
194 loss += exp( - scalar_t(labels[n] * 2 - 1) * responses[n]);
197 cout << "Final loss is " << loss << endl;
// NOTE(review): only `derivatives` and the integral images are released in
// the visible lines; confirm `responses` is freed on a line not shown here.
199 delete[] derivatives;
202 for(int n = 0; n < nb_vignettes; n++) {
203 delete[] integral_images[n];
206 delete[] integral_images;
// Scores `vignette` by summing the responses of all _nb_stumps stumps on its
// integral image. The declaration of the accumulator `result` and the return
// statement are on lines not visible here; presumably the accumulated
// `result` is returned.
209 scalar_t BoostedClassifier::classify(Vignette *vignette) {
// Local (automatic) buffer of (width + 1) * (height + 1) entries.
210 int integral_image[(Vignette::width + 1) * (Vignette::height + 1)];
211 compute_integral_image(vignette, integral_image);
213 for(int n = 0; n < _nb_stumps; n++) {
214 result += _stumps[n].response(integral_image);
// Loads the classifier from `in`: the stump count via read_var(), then the
// raw bytes of _nb_stumps Stump objects into a freshly allocated array.
// One line of the body (before read_var) is not visible in this listing.
//
// NOTE(review): the stream state is never checked in the visible lines, so a
// truncated or corrupt file leaves stumps partially uninitialized. The raw
// byte copy also assumes Stump is trivially copyable and that the file was
// written with the same struct layout.
// NOTE(review): _stumps is reassigned with new[]; confirm any previously
// allocated array is released (possibly on the hidden line), otherwise it
// leaks.
219 void BoostedClassifier::read(istream *in) {
221 read_var(in, &_nb_stumps);
222 cout << "Reading " << _nb_stumps << " stumps." << endl;
223 _stumps = new Stump[_nb_stumps];
224 in->read((char *) _stumps, sizeof(Stump) * _nb_stumps);
// Saves the classifier to `out`: the stump count via write_var(), then the
// raw bytes of the _stumps array (the same count + raw bytes sequence that
// read() consumes). Some lines of the body before the count are not visible
// in this listing. The format depends on Stump's in-memory layout.
227 void BoostedClassifier::write(ostream *out) {
231 write_var(out, &_nb_stumps);
232 out->write((char *) _stumps, sizeof(Stump) * _nb_stumps);
// Accumulates the responses of stumps first .. first + nb - 1 on the integral
// image of `vignette`. The accumulator's declaration and the rest of the body
// are not visible in this listing.
// NOTE(review): the visible lines do not check first / nb against
// _nb_stumps; out-of-range arguments would read past the _stumps array.
235 scalar_t BoostedClassifier::partial_sum(int first, int nb, Vignette *vignette) {
236 int integral_image[(Vignette::width + 1) * (Vignette::height + 1)];
237 compute_integral_image(vignette, integral_image);
239 for(int n = first; n < first + nb; n++) {
240 result += _stumps[n].response(integral_image);