[folded-ctf.git] / detector.cc
/*
 *  folded-ctf is an implementation of the folded hierarchy of
 *  classifiers for object detection, developed by Francois Fleuret
 *  and Donald Geman.
 *
 *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
 *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
 *
 *  This file is part of folded-ctf.
 *
 *  folded-ctf is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published
 *  by the Free Software Foundation, either version 3 of the License,
 *  or (at your option) any later version.
 *
 *  folded-ctf is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with folded-ctf.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

#include "tools.h"
#include "detector.h"
#include "global.h"
#include "classifier_reader.h"
#include "pose_cell_hierarchy_reader.h"

Detector::Detector() {
  _hierarchy = 0;
  _nb_levels = 0;
  _nb_classifiers_per_level = 0;
  _thresholds = 0;
  _nb_classifiers = 0;
  _classifiers = 0;
  _pi_feature_families = 0;
}

Detector::~Detector() {
  if(_hierarchy) {
    delete[] _thresholds;
    for(int q = 0; q < _nb_classifiers; q++) {
      delete _classifiers[q];
      delete _pi_feature_families[q];
    }
    delete[] _classifiers;
    delete[] _pi_feature_families;
    delete _hierarchy;
  }
}

//////////////////////////////////////////////////////////////////////
// Training

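// Trains one classifier: draws a large randomized pi-feature family
// for the given level, builds a sample set from the positive cells of
// the parsing pool plus negatives sampled according to the current
// loss, boosts the classifier on it, and extracts into
// pi_feature_family the subset of pi-features the trained classifier
// actually uses.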
void Detector::train_classifier(int level,
                                LossMachine *loss_machine,
                                ParsingPool *parsing_pool,
                                PiFeatureFamily *pi_feature_family,
                                Classifier *classifier) {

  // Randomize the pi-feature family

  PiFeatureFamily full_pi_feature_family;

  full_pi_feature_family.resize(global.nb_features_for_boosting_optimization);
  full_pi_feature_family.randomize(level);

  int nb_positives = parsing_pool->nb_positive_cells();

  int nb_negatives_to_sample =
    parsing_pool->nb_positive_cells() * global.nb_negative_samples_per_positive;

  SampleSet *sample_set = new SampleSet(full_pi_feature_family.nb_features(),
                                        nb_positives + nb_negatives_to_sample);

  scalar_t *responses = new scalar_t[nb_positives + nb_negatives_to_sample];

  parsing_pool->weighted_sampling(loss_machine,
                                  &full_pi_feature_family,
                                  sample_set,
                                  responses);

  (*global.log_stream) << "Initial train_loss "
                       << loss_machine->loss(sample_set, responses)
                       << endl;

  classifier->train(loss_machine, sample_set, responses);
  classifier->extract_pi_feature_family(&full_pi_feature_family, pi_feature_family);

  delete[] responses;
  delete sample_set;
}

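// Trains the full detector: builds the pose cell hierarchy from
// hierarchy_pool, checks that every annotated pose of the training
// and validation sets is compatible with it, allocates the
// _nb_levels * _nb_classifiers_per_level boosted classifiers, and
// trains them level after level while keeping the cell responses of
// the parsing pool(s) up to date and writing ROC curves along the
// way. The thresholds are initialized to 0 here and set for real in
// compute_thresholds.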
void Detector::train(LabelledImagePool *train_pool,
                     LabelledImagePool *validation_pool,
                     LabelledImagePool *hierarchy_pool) {

  if(_hierarchy) {
    cerr << "Cannot re-train a Detector" << endl;
    exit(1);
  }

  _hierarchy = new PoseCellHierarchy(hierarchy_pool);

  int nb_violations;

  nb_violations = _hierarchy->nb_incompatible_poses(train_pool);

  if(nb_violations > 0) {
    cout << "The hierarchy is incompatible with the training set ("
         << nb_violations
         << " violations)." << endl;
    exit(1);
  }

  nb_violations = _hierarchy->nb_incompatible_poses(validation_pool);

  if(nb_violations > 0) {
    cout << "The hierarchy is incompatible with the validation set ("
         << nb_violations << " violations)."
         << endl;
    exit(1);
  }

  _nb_levels = _hierarchy->nb_levels();
  _nb_classifiers_per_level = global.nb_classifiers_per_level;
  _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
  _thresholds = new scalar_t[_nb_classifiers];
  _classifiers = new Classifier *[_nb_classifiers];
  _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];

  for(int q = 0; q < _nb_classifiers; q++) {
    _classifiers[q] = new BoostedClassifier(global.nb_weak_learners_per_classifier);
    _pi_feature_families[q] = new PiFeatureFamily();
  }

  ParsingPool *train_parsing, *validation_parsing;

  train_parsing = new ParsingPool(train_pool,
                                  _hierarchy,
                                  global.proportion_negative_cells_for_training);

  if(global.write_validation_rocs) {
    validation_parsing = new ParsingPool(validation_pool,
                                         _hierarchy,
                                         global.proportion_negative_cells_for_training);
  } else {
    validation_parsing = 0;
  }

  LossMachine *loss_machine = new LossMachine(global.loss_type);

  cout << "Building a detector." << endl;

  global.bar.init(&cout, _nb_classifiers);

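  // Level after level: below the root level, first move the parsing
  // pool(s) one level down the hierarchy, then train the classifiers
  // of that level one by one, accumulating their responses into the
  // cell scores and saving a ROC curve after each of them.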
  for(int l = 0; l < _nb_levels; l++) {

    if(l > 0) {
      train_parsing->down_one_level(loss_machine, _hierarchy, l);
      if(validation_parsing) {
        validation_parsing->down_one_level(loss_machine, _hierarchy, l);
      }
    }

    for(int c = 0; c < _nb_classifiers_per_level; c++) {
      int q = l * _nb_classifiers_per_level + c;

      // Train the classifier

      train_classifier(l,
                       loss_machine,
                       train_parsing,
                       _pi_feature_families[q], _classifiers[q]);

      // Update the cell responses on the training set

      train_parsing->update_cell_responses(_pi_feature_families[q],
                                           _classifiers[q]);

      // Save the ROC curves on the training set

      char buffer[buffer_size];

      sprintf(buffer, "%s/train_%05d.roc",
              global.result_path,
              (q + 1) * global.nb_weak_learners_per_classifier);
      ofstream out(buffer);
      train_parsing->write_roc(&out);

      if(validation_parsing) {

        // Update the cell responses on the validation set

        validation_parsing->update_cell_responses(_pi_feature_families[q],
                                                  _classifiers[q]);

        // Save the ROC curves on the validation set

        sprintf(buffer, "%s/validation_%05d.roc",
                global.result_path,
                (q + 1) * global.nb_weak_learners_per_classifier);
        ofstream out_validation(buffer);
        validation_parsing->write_roc(&out_validation);
      }

      _thresholds[q] = 0.0;

      global.bar.refresh(&cout, q);
    }
  }

  global.bar.finish(&cout);

  delete loss_machine;
  delete train_parsing;
  delete validation_parsing;
}

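// Calibrates the rejection thresholds on a validation set so that the
// detector keeps roughly a proportion wanted_tp of the true targets.
// First pass: descend the hierarchy along the true pose of every
// target and record its accumulated response after each classifier.
// Second pass: for each classifier, pick the threshold that discards
// at most the allowed number of additional targets.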
void Detector::compute_thresholds(LabelledImagePool *validation_pool, scalar_t wanted_tp) {
  LabelledImage *image;
  int nb_targets_total = 0;

  for(int i = 0; i < validation_pool->nb_images(); i++) {
    image = validation_pool->grab_image(i);
    nb_targets_total += image->nb_targets();
    validation_pool->release_image(i);
  }

  scalar_t *responses = new scalar_t[_nb_classifiers * nb_targets_total];

  int tt = 0;

  for(int i = 0; i < validation_pool->nb_images(); i++) {
    image = validation_pool->grab_image(i);
    image->compute_rich_structure();

    PoseCell current_cell;

    for(int t = 0; t < image->nb_targets(); t++) {

      scalar_t response = 0;

      for(int l = 0; l < _nb_levels; l++) {

        // We get the next-level cell for that target

        PoseCellSet cell_set;

        cell_set.erase_content();
        if(l == 0) {
          _hierarchy->add_root_cells(image, &cell_set);
        } else {
          _hierarchy->add_subcells(l, &current_cell, &cell_set);
        }

        int nb_compliant = 0;

        for(int c = 0; c < cell_set.nb_cells(); c++) {
          if(cell_set.get_cell(c)->contains(image->get_target_pose(t))) {
            current_cell = *(cell_set.get_cell(c));
            nb_compliant++;
          }
        }

        if(nb_compliant != 1) {
          cerr << "INCONSISTENCY (" << nb_compliant << " should be one)" << endl;
          abort();
        }

        for(int c = 0; c < _nb_classifiers_per_level; c++) {
          int q = l * _nb_classifiers_per_level + c;
          SampleSet *sample_set = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
          sample_set->set_sample(0, _pi_feature_families[q], image, &current_cell, 0);
          response += _classifiers[q]->response(sample_set, 0);
          delete sample_set;
          responses[tt + nb_targets_total * q] = response;
        }

      }

      tt++;
    }

    validation_pool->release_image(i);
  }

  ASSERT(tt == nb_targets_total);

  // At this point, responses[t + nb_targets_total * q] contains the
  // accumulated response of target t after classifier q

  int *still_detected = new int[nb_targets_total];
  int *indexes = new int[nb_targets_total];
  int *sorted_indexes = new int[nb_targets_total];

  for(int t = 0; t < nb_targets_total; t++) {
    still_detected[t] = 1;
    indexes[t] = t;
  }

  int current_nb_fn = 0;

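  // The target cumulative true-positive rate after classifier q is
  // wanted_tp^((q+1)/_nb_classifiers), a geometric schedule that
  // reaches wanted_tp after the last classifier. For each classifier,
  // the targets are sorted by their accumulated response and the
  // threshold is raised, sacrificing the lowest-scoring targets
  // first, until the corresponding number of false negatives is
  // reached.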
  for(int q = 0; q < _nb_classifiers; q++) {

    scalar_t wanted_tp_at_this_classifier
      = exp(log(wanted_tp) * scalar_t(q + 1) / scalar_t(_nb_classifiers));

    int wanted_nb_fn_at_this_classifier
      = int(nb_targets_total * (1 - wanted_tp_at_this_classifier));

    indexed_fusion_sort(nb_targets_total, indexes, sorted_indexes,
                        responses + q * nb_targets_total);

    for(int t = 0; (current_nb_fn < wanted_nb_fn_at_this_classifier) && (t < nb_targets_total - 1); t++) {
      int u = sorted_indexes[t];
      int v = sorted_indexes[t+1];
      _thresholds[q] = responses[v + nb_targets_total * q];
      if(still_detected[u]) {
        still_detected[u] = 0;
        current_nb_fn++;
      }
    }
  }

  delete[] still_detected;
  delete[] indexes;
  delete[] sorted_indexes;
  delete[] responses;
}

//////////////////////////////////////////////////////////////////////
// Parsing

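// Recursive coarse-to-fine parsing: given a pose cell at a certain
// level, evaluates the classifiers of that level on all its subcells
// (or on the root cells at level 0), discards the cells whose
// accumulated response falls under the corresponding threshold, and
// recurses into the surviving ones. Cells that reach the last level
// are added to the result with their accumulated score.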
void Detector::parse_rec(RichImage *image, int level,
                         PoseCell *cell, scalar_t current_response,
                         PoseCellScoredSet *result) {

  if(level == _nb_levels) {
    result->add_cell_with_score(cell, current_response);
    return;
  }

  PoseCellSet cell_set;
  cell_set.erase_content();

  if(level == 0) {
    _hierarchy->add_root_cells(image, &cell_set);
  } else {
    _hierarchy->add_subcells(level, cell, &cell_set);
  }

  scalar_t *responses = new scalar_t[cell_set.nb_cells()];
  int *keep = new int[cell_set.nb_cells()];

  for(int c = 0; c < cell_set.nb_cells(); c++) {
    responses[c] = current_response;
    keep[c] = 1;
  }

  for(int a = 0; a < _nb_classifiers_per_level; a++) {
    int q = level * _nb_classifiers_per_level + a;
    SampleSet *samples = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
    for(int c = 0; c < cell_set.nb_cells(); c++) {
      if(keep[c]) {
        samples->set_sample(0, _pi_feature_families[q], image, cell_set.get_cell(c), 0);
        responses[c] += _classifiers[q]->response(samples, 0);
        keep[c] = responses[c] >= _thresholds[q];
      }
    }
    delete samples;
  }

  for(int c = 0; c < cell_set.nb_cells(); c++) {
    if(keep[c]) {
      parse_rec(image, level + 1, cell_set.get_cell(c), responses[c], result);
    }
  }

  delete[] keep;
  delete[] responses;
}

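// Entry point for detection on one image: clears result_cell_set and
// fills it with the pose cells that survive all the classifiers,
// together with their accumulated scores.
//
// A minimal usage sketch (assuming, as for the validation images in
// compute_thresholds above, that the image has been preprocessed with
// compute_rich_structure() first):
//
//   image->compute_rich_structure();
//   PoseCellScoredSet detections;
//   detector.parse(image, &detections);
//   // detections now contains the surviving pose cells and scores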
void Detector::parse(RichImage *image, PoseCellScoredSet *result_cell_set) {
  result_cell_set->erase_content();
  parse_rec(image, 0, 0, 0, result_cell_set);
}

//////////////////////////////////////////////////////////////////////
// Storage

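// Serialization format: _nb_levels and _nb_classifiers_per_level,
// then for every classifier its pi-feature family, the classifier
// itself and its threshold, and finally the pose cell hierarchy.
// read() mirrors write() below and refuses to load over an already
// initialized Detector.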
void Detector::read(istream *is) {
  if(_hierarchy) {
    cerr << "Cannot read over an existing Detector" << endl;
    exit(1);
  }

  read_var(is, &_nb_levels);
  read_var(is, &_nb_classifiers_per_level);

  _nb_classifiers = _nb_levels * _nb_classifiers_per_level;

  _classifiers = new Classifier *[_nb_classifiers];
  _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
  _thresholds = new scalar_t[_nb_classifiers];

  for(int q = 0; q < _nb_classifiers; q++) {
    _pi_feature_families[q] = new PiFeatureFamily();
    _pi_feature_families[q]->read(is);
    _classifiers[q] = read_classifier(is);
    read_var(is, &_thresholds[q]);
  }

  _hierarchy = read_hierarchy(is);
}

void Detector::write(ostream *os) {
  write_var(os, &_nb_levels);
  write_var(os, &_nb_classifiers_per_level);

  for(int q = 0; q < _nb_classifiers; q++) {
    _pi_feature_families[q]->write(os);
    _classifiers[q]->write(os);
    write_var(os, &_thresholds[q]);
  }

  _hierarchy->write(os);
}