2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation. //
7 // This program is distributed in the hope that it will be useful, but //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU //
10 // General Public License for more details. //
12 // You should have received a copy of the GNU General Public License //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>. //
15 // Written by Francois Fleuret, (C) IDIAP //
16 // Contact <francois.fleuret@idiap.ch> for comments & bug reports //
17 ///////////////////////////////////////////////////////////////////////////
#include <cstdio>
22 #include "classifier_reader.h"
23 #include "pose_cell_hierarchy_reader.h"
// Default constructor. The visible initializers reset the number of
// classifiers per level and the pi-feature family pointer array to zero.
// The per-classifier arrays are only allocated later, by train() or read().
25 Detector::Detector() {
28 _nb_classifiers_per_level = 0;
32 _pi_feature_families = 0;
// Destructor: deletes each of the _nb_classifiers Classifier and
// PiFeatureFamily objects, then the two pointer arrays that hold them.
// NOTE(review): the loop relies on _nb_classifiers being 0 for a detector
// that was never trained or read; that initialization is not visible in
// the constructor excerpt -- confirm.
// NOTE(review): _thresholds and _hierarchy are also heap-allocated by
// train()/read(); their release is not visible in this excerpt -- confirm
// they are freed.
36 Detector::~Detector() {
39 for(int q = 0; q < _nb_classifiers; q++) {
40 delete _classifiers[q];
41 delete _pi_feature_families[q];
43 delete[] _classifiers;
44 delete[] _pi_feature_families;
49 //////////////////////////////////////////////////////////////////////
// Trains one classifier of the cascade for hierarchy level `level`.
//
// level             hierarchy level; used to randomize the candidate
//                   pi-feature family.
// loss_machine      loss used to report the initial training loss and to
//                   train `classifier`.
// parsing_pool      source of the positive cells and of the sampled negatives.
// pi_feature_family output: filled by classifier->extract_pi_feature_family()
//                   from the randomized full family (presumably the features
//                   the trained classifier actually uses -- confirm).
// classifier        classifier trained in place.
//
// The training set holds every positive cell plus
// global.nb_negative_samples_per_positive negatives per positive.
// NOTE(review): sample_set and responses are allocated with new below; their
// deletion is not visible in this excerpt -- confirm they are released.
52 void Detector::train_classifier(int level,
53 LossMachine *loss_machine,
54 ParsingPool *parsing_pool,
55 PiFeatureFamily *pi_feature_family,
56 Classifier *classifier) {
58 // Randomize the pi-feature family
60 PiFeatureFamily full_pi_feature_family;
62 full_pi_feature_family.resize(global.nb_features_for_boosting_optimization);
63 full_pi_feature_family.randomize(level);
65 int nb_positives = parsing_pool->nb_positive_cells();
// The negative budget scales with the number of positive cells.
67 int nb_negatives_to_sample =
68 parsing_pool->nb_positive_cells() * global.nb_negative_samples_per_positive;
70 SampleSet *sample_set = new SampleSet(full_pi_feature_family.nb_features(),
71 nb_positives + nb_negatives_to_sample);
73 scalar_t *responses = new scalar_t[nb_positives + nb_negatives_to_sample];
75 (*global.log_stream) << "Collecting the sampled training set." << endl;
// NOTE(review): the remaining arguments of weighted_sampling() are not in
// this excerpt; presumably it fills sample_set and responses -- confirm.
77 parsing_pool->weighted_sampling(loss_machine,
78 &full_pi_feature_family,
82 (*global.log_stream) << "Training the classifier." << endl;
84 (*global.log_stream) << "Initial train_loss "
85 << loss_machine->loss(sample_set, responses)
88 classifier->train(loss_machine, sample_set, responses);
89 classifier->extract_pi_feature_family(&full_pi_feature_family, pi_feature_family);
// Builds the whole detector from scratch.
//
// train_pool      images on which the classifiers are trained.
// validation_pool images checked against the hierarchy; when
//                 global.write_validation_rocs is set, also used to write
//                 validation ROC curves.
// hierarchy_pool  images from which the PoseCellHierarchy is built.
//
// Steps: build the hierarchy and report pose incompatibilities on the
// training and validation sets. Allocate _nb_levels *
// global.nb_classifiers_per_level boosted classifiers. Then, level by level,
// advance the parsing pools with down_one_level() and train each
// classifier. After each classifier, update the cell responses and write a
// ROC file. Every threshold is set to 0.0 here; compute_thresholds() can
// adjust them afterwards.
// Prints an error when the detector was already built (the guarding
// condition is not visible in this excerpt).
// NOTE(review): loss_machine is allocated with new; its deletion is not
// visible in this excerpt -- confirm it is released.
95 void Detector::train(LabelledImagePool *train_pool,
96 LabelledImagePool *validation_pool,
97 LabelledImagePool *hierarchy_pool) {
100 cerr << "Can not re-train a Detector" << endl;
104 _hierarchy = new PoseCellHierarchy(hierarchy_pool);
108 nb_violations = _hierarchy->nb_incompatible_poses(train_pool);
110 if(nb_violations > 0) {
111 cout << "The hierarchy is incompatible with the training set ("
113 << " violations)." << endl;
117 nb_violations = _hierarchy->nb_incompatible_poses(validation_pool);
119 if(nb_violations > 0) {
120 cout << "The hierarchy is incompatible with the validation set ("
121 << nb_violations << " violations)."
126 _nb_levels = _hierarchy->nb_levels();
127 _nb_classifiers_per_level = global.nb_classifiers_per_level;
128 _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
129 _thresholds = new scalar_t[_nb_classifiers];
130 _classifiers = new Classifier *[_nb_classifiers];
131 _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
133 for(int q = 0; q < _nb_classifiers; q++) {
134 _classifiers[q] = new BoostedClassifier(global.nb_weak_learners_per_classifier);
135 _pi_feature_families[q] = new PiFeatureFamily();
138 ParsingPool *train_parsing, *validation_parsing;
140 train_parsing = new ParsingPool(train_pool,
142 global.proportion_negative_cells_for_training);
// The validation pool is only parsed when its ROC curves are requested.
144 if(global.write_validation_rocs) {
145 validation_parsing = new ParsingPool(validation_pool,
147 global.proportion_negative_cells_for_training);
149 validation_parsing = 0;
152 LossMachine *loss_machine = new LossMachine(global.loss_type);
154 cout << "Building a detector." << endl;
156 global.bar.init(&cout, _nb_classifiers);
158 for(int l = 0; l < _nb_levels; l++) {
161 train_parsing->down_one_level(loss_machine, _hierarchy, l);
162 if(validation_parsing) {
163 validation_parsing->down_one_level(loss_machine, _hierarchy, l);
167 for(int c = 0; c < _nb_classifiers_per_level; c++) {
// Classifiers are stored level-major: index = level * per_level + rank.
168 int q = l * _nb_classifiers_per_level + c;
170 (*global.log_stream) << "Building classifier " << q << " (level " << l << ")" << endl;
172 // Train the classifier
177 _pi_feature_families[q], _classifiers[q]);
179 // Update the cell responses on the training set
181 (*global.log_stream) << "Updating training cell responses." << endl;
183 train_parsing->update_cell_responses(_pi_feature_families[q],
186 // Save the ROC curves on the training set
188 char buffer[buffer_size];
// snprintf bounds the write to the buffer: the %s path has no length
// limit, so sprintf could overflow this fixed-size stack array.
190 snprintf(buffer, buffer_size, "%s/train_%05d.roc",
192 (q + 1) * global.nb_weak_learners_per_classifier);
193 ofstream out(buffer);
194 train_parsing->write_roc(&out);
196 if(validation_parsing) {
198 // Update the cell responses on the validation set
200 (*global.log_stream) << "Updating validation cell responses." << endl;
202 validation_parsing->update_cell_responses(_pi_feature_families[q],
205 // Save the ROC curves on the validation set
// Bounded for the same reason as the training ROC filename above.
207 snprintf(buffer, buffer_size, "%s/validation_%05d.roc",
209 (q + 1) * global.nb_weak_learners_per_classifier);
210 ofstream out(buffer);
211 validation_parsing->write_roc(&out);
// Provisional threshold; compute_thresholds() may overwrite it.
214 _thresholds[q] = 0.0;
216 global.bar.refresh(&cout, q);
220 global.bar.finish(&cout);
223 delete train_parsing;
224 delete validation_parsing;
// Calibrates _thresholds[] on a validation pool so that the whole cascade
// keeps roughly a fraction `wanted_tp` of the validation targets.
//
// validation_pool images whose targets are used for calibration.
// wanted_tp       overall true-positive rate wanted at the end of the cascade.
//
// Each target's cell is followed down the hierarchy. The cumulative
// classifier response is recorded after every classifier q.
// The detection budget is spread geometrically over the cascade. After
// classifier q the wanted rate is wanted_tp^((q + 1) / _nb_classifiers), so
// at most int(nb_targets_total * (1 - rate)) targets may have been dropped.
// For each q, targets are walked in the order returned by
// indexed_fusion_sort() (presumably ascending response). The threshold is
// raised to the next target's response until that false-negative budget is
// reached. Finally the overall detection rate is logged, and an
// inconsistency is reported if it falls below wanted_tp.
227 void Detector::compute_thresholds(LabelledImagePool *validation_pool, scalar_t wanted_tp) {
228 LabelledImage *image;
229 int nb_targets_total = 0;
231 for(int i = 0; i < validation_pool->nb_images(); i++) {
232 image = validation_pool->grab_image(i);
233 nb_targets_total += image->nb_targets();
234 validation_pool->release_image(i);
// One cumulative response per (classifier, target), classifier-major.
237 scalar_t *responses = new scalar_t[_nb_classifiers * nb_targets_total];
241 for(int i = 0; i < validation_pool->nb_images(); i++) {
242 image = validation_pool->grab_image(i);
243 image->compute_rich_structure();
245 PoseCell current_cell;
247 for(int t = 0; t < image->nb_targets(); t++) {
249 scalar_t response = 0;
251 for(int l = 0; l < _nb_levels; l++) {
253 // We get the next-level cell for that target
255 PoseCellSet cell_set;
257 cell_set.erase_content();
259 _hierarchy->add_root_cells(image, &cell_set);
261 _hierarchy->add_subcells(l, &current_cell, &cell_set);
// Exactly one candidate cell is expected to contain the target pose;
// it becomes the cell used for this level and the parent for the next.
264 int nb_compliant = 0;
266 for(int c = 0; c < cell_set.nb_cells(); c++) {
267 if(cell_set.get_cell(c)->contains(image->get_target_pose(t))) {
268 current_cell = *(cell_set.get_cell(c));
273 if(nb_compliant != 1) {
274 cerr << "INCONSISTENCY (" << nb_compliant << " should be one)" << endl;
278 for(int c = 0; c < _nb_classifiers_per_level; c++) {
279 int q = l * _nb_classifiers_per_level + c;
280 SampleSet *sample_set = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
281 sample_set->set_sample(0, _pi_feature_families[q], image, &current_cell, 0);
282 response +=_classifiers[q]->response(sample_set, 0);
284 responses[tt + nb_targets_total * q] = response;
292 validation_pool->release_image(i);
295 ASSERT(tt == nb_targets_total);
297 // Here we have in responses[] all the target responses after every
// classifier: responses[t + nb_targets_total * q] is the cumulative score
// of target t after classifier q.
300 int *still_detected = new int[nb_targets_total];
301 int *indexes = new int[nb_targets_total];
302 int *sorted_indexes = new int[nb_targets_total];
304 for(int t = 0; t < nb_targets_total; t++) {
305 still_detected[t] = 1;
// Number of targets rejected so far, accumulated across classifiers.
309 int current_nb_fn = 0;
311 for(int q = 0; q < _nb_classifiers; q++) {
313 scalar_t wanted_tp_at_this_classifier
314 = exp(log(wanted_tp) * scalar_t(q + 1) / scalar_t(_nb_classifiers));
316 int wanted_nb_fn_at_this_classifier
317 = int(nb_targets_total * (1 - wanted_tp_at_this_classifier));
319 (*global.log_stream) << "q = " << q
320 << " wanted_tp_at_this_classifier = " << wanted_tp_at_this_classifier
321 << " wanted_nb_fn_at_this_classifier = " << wanted_nb_fn_at_this_classifier
324 indexed_fusion_sort(nb_targets_total, indexes, sorted_indexes,
325 responses + q * nb_targets_total);
// NOTE(review): _thresholds[q] is only assigned inside this loop. When the
// false-negative budget is already met, it keeps its previous value (0.0
// after train(), the stored value after read()) -- confirm that is intended.
327 for(int t = 0; (current_nb_fn < wanted_nb_fn_at_this_classifier) && (t < nb_targets_total - 1); t++) {
328 int u = sorted_indexes[t];
329 int v = sorted_indexes[t+1];
330 _thresholds[q] = responses[v + nb_targets_total * q];
331 if(still_detected[u]) {
332 still_detected[u] = 0;
338 delete[] still_detected;
340 delete[] sorted_indexes;
// Sanity check: count the targets whose responses pass every threshold and
// compare the resulting detection rate with wanted_tp.
342 { ////////////////////////////////////////////////////////////////////
345 int nb_positives = 0;
347 for(int t = 0; t < nb_targets_total; t++) {
349 for(int q = 0; q < _nb_classifiers; q++) {
350 if(responses[t + nb_targets_total * q] < _thresholds[q]) positive = 0;
352 if(positive) nb_positives++;
355 scalar_t actual_tp = scalar_t(nb_positives) / scalar_t(nb_targets_total);
357 (*global.log_stream) << "Overall detection rate " << nb_positives << "/" << nb_targets_total
359 << "actual_tp = " << actual_tp
361 << "wanted_tp = " << wanted_tp
364 if(actual_tp < wanted_tp) {
365 cerr << "INCONSISTENCY" << endl;
368 } ////////////////////////////////////////////////////////////////////
373 //////////////////////////////////////////////////////////////////////
// Recursively descends the pose-cell hierarchy, accumulating classifier
// responses, and adds each cell reaching level _nb_levels to `result`
// together with its cumulative score.
//
// image            image being parsed.
// level            current depth; when it equals _nb_levels, `cell` is added
//                  to `result` with current_response.
// cell             parent cell (parse() passes 0 at level 0).
// current_response cumulative response inherited from the parent cell.
// result           output set of scored cells.
//
// At each level the candidates are the image's root cells and/or the
// subcells of `cell`; the condition choosing between them is not visible
// in this excerpt. Every classifier of the level adds its response to each
// candidate's score. keep[c] records whether that score is at or above the
// classifier's threshold. Each candidate is then recursed into at level + 1.
// NOTE(review): as visible here, keep[] is never initialized, and keep[c]
// is overwritten by every classifier of the level, so only the last
// threshold would decide. Confirm that guards such as `if(keep[c])` exist
// before the scoring and before the recursive call (not in this excerpt).
// Also confirm that samples, responses and keep are released.
376 void Detector::parse_rec(RichImage *image, int level,
377 PoseCell *cell, scalar_t current_response,
378 PoseCellScoredSet *result) {
380 if(level == _nb_levels) {
381 result->add_cell_with_score(cell, current_response);
385 PoseCellSet cell_set;
386 cell_set.erase_content();
389 _hierarchy->add_root_cells(image, &cell_set);
391 _hierarchy->add_subcells(level, cell, &cell_set);
394 scalar_t *responses = new scalar_t[cell_set.nb_cells()];
395 int *keep = new int[cell_set.nb_cells()];
397 for(int c = 0; c < cell_set.nb_cells(); c++) {
398 responses[c] = current_response;
402 for(int a = 0; a < _nb_classifiers_per_level; a++) {
403 int q = level * _nb_classifiers_per_level + a;
// A single-sample set reused for every candidate cell of this classifier.
404 SampleSet *samples = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
405 for(int c = 0; c < cell_set.nb_cells(); c++) {
407 samples->set_sample(0, _pi_feature_families[q], image, cell_set.get_cell(c), 0);
408 responses[c] += _classifiers[q]->response(samples, 0);
409 keep[c] = responses[c] >= _thresholds[q];
415 for(int c = 0; c < cell_set.nb_cells(); c++) {
417 parse_rec(image, level + 1, cell_set.get_cell(c), responses[c], result);
// Detection entry point: clears result_cell_set, then starts the recursive
// descent (parse_rec) at level 0 with no parent cell and a zero initial
// response. Cells reaching the last level are added to result_cell_set
// with their cumulative score.
425 void Detector::parse(RichImage *image, PoseCellScoredSet *result_cell_set) {
426 result_cell_set->erase_content();
427 parse_rec(image, 0, 0, 0, result_cell_set);
430 //////////////////////////////////////////////////////////////////////
// Deserializes a detector saved by write(). The stream holds the number of
// levels and of classifiers per level. Then, for each classifier, come its
// pi-feature family, the classifier itself (read_classifier()) and its
// threshold. The pose-cell hierarchy comes last (read_hierarchy()).
// _nb_classifiers is not stored; it is recomputed as
// _nb_levels * _nb_classifiers_per_level.
// Prints an error when called on an existing detector (the guarding
// condition is not visible in this excerpt).
433 void Detector::read(istream *is) {
435 cerr << "Can not read over an existing Detector" << endl;
439 read_var(is, &_nb_levels);
440 read_var(is, &_nb_classifiers_per_level);
442 _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
444 _classifiers = new Classifier *[_nb_classifiers];
445 _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
446 _thresholds = new scalar_t[_nb_classifiers];
448 for(int q = 0; q < _nb_classifiers; q++) {
449 cout << "Read classifier " << q << endl;
450 _pi_feature_families[q] = new PiFeatureFamily();
451 _pi_feature_families[q]->read(is);
452 _classifiers[q] = read_classifier(is);
453 read_var(is, &_thresholds[q]);
456 _hierarchy = read_hierarchy(is);
458 (*global.log_stream) << "Read Detector" << endl
459 << " _nb_levels " << _nb_levels << endl
460 << " _nb_classifiers_per_level " << _nb_classifiers_per_level << endl;
// Serializes the detector in the order read() expects: _nb_levels,
// _nb_classifiers_per_level, then for each of the _nb_classifiers
// classifiers its pi-feature family, the classifier and its threshold,
// and finally the pose-cell hierarchy. _nb_classifiers itself is not
// written; read() recomputes it from the two counts.
463 void Detector::write(ostream *os) {
464 write_var(os, &_nb_levels);
465 write_var(os, &_nb_classifiers_per_level);
467 for(int q = 0; q < _nb_classifiers; q++) {
468 _pi_feature_families[q]->write(os);
469 _classifiers[q]->write(os);
470 write_var(os, &_thresholds[q]);
473 _hierarchy->write(os);