automatic commit
[folded-ctf.git] / detector.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret, (C) IDIAP                                //
16 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
17 ///////////////////////////////////////////////////////////////////////////
18
19 #include "tools.h"
20 #include "detector.h"
21 #include "global.h"
22 #include "classifier_reader.h"
23 #include "pose_cell_hierarchy_reader.h"
24
25 Detector::Detector() {
26   _hierarchy = 0;
27   _nb_levels = 0;
28   _nb_classifiers_per_level = 0;
29   _thresholds = 0;
30   _nb_classifiers = 0;
31   _classifiers = 0;
32   _pi_feature_families = 0;
33 }
34
35
36 Detector::~Detector() {
37   if(_hierarchy) {
38     delete[] _thresholds;
39     for(int q = 0; q < _nb_classifiers; q++) {
40       delete _classifiers[q];
41       delete _pi_feature_families[q];
42     }
43     delete[] _classifiers;
44     delete[] _pi_feature_families;
45     delete _hierarchy;
46   }
47 }
48
49 //////////////////////////////////////////////////////////////////////
50 // Training
51
52 void Detector::train_classifier(int level,
53                                 LossMachine *loss_machine,
54                                 ParsingPool *parsing_pool,
55                                 PiFeatureFamily *pi_feature_family,
56                                 Classifier *classifier) {
57
58   // Randomize the pi-feature family
59
60   PiFeatureFamily full_pi_feature_family;
61
62   full_pi_feature_family.resize(global.nb_features_for_boosting_optimization);
63   full_pi_feature_family.randomize(level);
64
65   int nb_positives = parsing_pool->nb_positive_cells();
66
67   int nb_negatives_to_sample =
68     parsing_pool->nb_positive_cells() * global.nb_negative_samples_per_positive;
69
70   SampleSet *sample_set = new SampleSet(full_pi_feature_family.nb_features(),
71                                         nb_positives + nb_negatives_to_sample);
72
73   scalar_t *responses = new scalar_t[nb_positives + nb_negatives_to_sample];
74
75   (*global.log_stream) << "Collecting the sampled training set." << endl;
76
77   parsing_pool->weighted_sampling(loss_machine,
78                                   &full_pi_feature_family,
79                                   sample_set,
80                                   responses);
81
82   (*global.log_stream) << "Training the classifier." << endl;
83
84   (*global.log_stream) << "Initial train_loss "
85                        << loss_machine->loss(sample_set, responses)
86                        << endl;
87
88   classifier->train(loss_machine, sample_set, responses);
89   classifier->extract_pi_feature_family(&full_pi_feature_family, pi_feature_family);
90
91   delete[] responses;
92   delete sample_set;
93 }
94
95 void Detector::train(LabelledImagePool *train_pool,
96                      LabelledImagePool *validation_pool,
97                      LabelledImagePool *hierarchy_pool) {
98
99   if(_hierarchy) {
100     cerr << "Can not re-train a Detector" << endl;
101     exit(1);
102   }
103
104   _hierarchy = new PoseCellHierarchy(hierarchy_pool);
105
106   int nb_violations;
107
108   nb_violations = _hierarchy->nb_incompatible_poses(train_pool);
109
110   if(nb_violations > 0) {
111     cout << "The hierarchy is incompatible with the training set ("
112          << nb_violations
113          << " violations)." << endl;
114     exit(1);
115   }
116
117   nb_violations = _hierarchy->nb_incompatible_poses(validation_pool);
118
119   if(nb_violations > 0) {
120     cout << "The hierarchy is incompatible with the validation set ("
121          << nb_violations << " violations)."
122          << endl;
123     exit(1);
124   }
125
126   _nb_levels = _hierarchy->nb_levels();
127   _nb_classifiers_per_level = global.nb_classifiers_per_level;
128   _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
129   _thresholds = new scalar_t[_nb_classifiers];
130   _classifiers = new Classifier *[_nb_classifiers];
131   _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
132
133   for(int q = 0; q < _nb_classifiers; q++) {
134     _classifiers[q] = new BoostedClassifier(global.nb_weak_learners_per_classifier);
135     _pi_feature_families[q] = new PiFeatureFamily();
136   }
137
138   ParsingPool *train_parsing, *validation_parsing;
139
140   train_parsing = new ParsingPool(train_pool,
141                                   _hierarchy,
142                                   global.proportion_negative_cells_for_training);
143
144   if(global.write_validation_rocs) {
145     validation_parsing = new ParsingPool(validation_pool,
146                                          _hierarchy,
147                                          global.proportion_negative_cells_for_training);
148   } else {
149     validation_parsing = 0;
150   }
151
152   LossMachine *loss_machine = new LossMachine(global.loss_type);
153
154   cout << "Building a detector." << endl;
155
156   global.bar.init(&cout, _nb_classifiers);
157
158   for(int l = 0; l < _nb_levels; l++) {
159
160     if(l > 0) {
161       train_parsing->down_one_level(loss_machine, _hierarchy, l);
162       if(validation_parsing) {
163         validation_parsing->down_one_level(loss_machine, _hierarchy, l);
164       }
165     }
166
167     for(int c = 0; c < _nb_classifiers_per_level; c++) {
168       int q = l * _nb_classifiers_per_level + c;
169
170       (*global.log_stream) << "Building classifier " << q << " (level " << l << ")" << endl;
171
172       // Train the classifier
173
174       train_classifier(l,
175                        loss_machine,
176                        train_parsing,
177                        _pi_feature_families[q], _classifiers[q]);
178
179       // Update the cell responses on the training set
180
181       (*global.log_stream) << "Updating training cell responses." << endl;
182
183       train_parsing->update_cell_responses(_pi_feature_families[q],
184                                            _classifiers[q]);
185
186       // Save the ROC curves on the training set
187
188       char buffer[buffer_size];
189
190       sprintf(buffer, "%s/train_%05d.roc",
191               global.result_path,
192               (q + 1) * global.nb_weak_learners_per_classifier);
193       ofstream out(buffer);
194       train_parsing->write_roc(&out);
195
196       if(validation_parsing) {
197
198         // Update the cell responses on the validation set
199
200         (*global.log_stream) << "Updating validation cell responses." << endl;
201
202         validation_parsing->update_cell_responses(_pi_feature_families[q],
203                                                   _classifiers[q]);
204
205         // Save the ROC curves on the validation set
206
207         sprintf(buffer, "%s/validation_%05d.roc",
208                 global.result_path,
209                 (q + 1) * global.nb_weak_learners_per_classifier);
210         ofstream out(buffer);
211         validation_parsing->write_roc(&out);
212       }
213
214       _thresholds[q] = 0.0;
215
216       global.bar.refresh(&cout, q);
217     }
218   }
219
220   global.bar.finish(&cout);
221
222   delete loss_machine;
223   delete train_parsing;
224   delete validation_parsing;
225 }
226
227 void Detector::compute_thresholds(LabelledImagePool *validation_pool, scalar_t wanted_tp) {
228   LabelledImage *image;
229   int nb_targets_total = 0;
230
231   for(int i = 0; i < validation_pool->nb_images(); i++) {
232     image = validation_pool->grab_image(i);
233     nb_targets_total += image->nb_targets();
234     validation_pool->release_image(i);
235   }
236
237   scalar_t *responses = new scalar_t[_nb_classifiers * nb_targets_total];
238
239   int tt = 0;
240
241   for(int i = 0; i < validation_pool->nb_images(); i++) {
242     image = validation_pool->grab_image(i);
243     image->compute_rich_structure();
244
245     PoseCell current_cell;
246
247     for(int t = 0; t < image->nb_targets(); t++) {
248
249       scalar_t response = 0;
250
251       for(int l = 0; l < _nb_levels; l++) {
252
253         // We get the next-level cell for that target
254
255         PoseCellSet cell_set;
256
257         cell_set.erase_content();
258         if(l == 0) {
259           _hierarchy->add_root_cells(image, &cell_set);
260         } else {
261           _hierarchy->add_subcells(l, &current_cell, &cell_set);
262         }
263
264         int nb_compliant = 0;
265
266         for(int c = 0; c < cell_set.nb_cells(); c++) {
267           if(cell_set.get_cell(c)->contains(image->get_target_pose(t))) {
268             current_cell = *(cell_set.get_cell(c));
269             nb_compliant++;
270           }
271         }
272
273         if(nb_compliant != 1) {
274           cerr << "INCONSISTENCY (" << nb_compliant << " should be one)" << endl;
275           abort();
276         }
277
278         for(int c = 0; c < _nb_classifiers_per_level; c++) {
279           int q = l * _nb_classifiers_per_level + c;
280           SampleSet *sample_set = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
281           sample_set->set_sample(0, _pi_feature_families[q], image, &current_cell, 0);
282           response +=_classifiers[q]->response(sample_set, 0);
283           delete sample_set;
284           responses[tt + nb_targets_total * q] = response;
285         }
286
287       }
288
289       tt++;
290     }
291
292     validation_pool->release_image(i);
293   }
294
295   ASSERT(tt == nb_targets_total);
296
297   // Here we have in responses[] all the target responses after every
298   // classifier
299
300   int *still_detected = new int[nb_targets_total];
301   int *indexes = new int[nb_targets_total];
302   int *sorted_indexes = new int[nb_targets_total];
303
304   for(int t = 0; t < nb_targets_total; t++) {
305     still_detected[t] = 1;
306     indexes[t] = t;
307   }
308
309   int current_nb_fn = 0;
310
311   for(int q = 0; q < _nb_classifiers; q++) {
312
313     scalar_t wanted_tp_at_this_classifier
314       = exp(log(wanted_tp) * scalar_t(q + 1) / scalar_t(_nb_classifiers));
315
316     int wanted_nb_fn_at_this_classifier
317       = int(nb_targets_total * (1 - wanted_tp_at_this_classifier));
318
319     (*global.log_stream) << "q = " << q
320                          << " wanted_tp_at_this_classifier = " << wanted_tp_at_this_classifier
321                          << " wanted_nb_fn_at_this_classifier = " << wanted_nb_fn_at_this_classifier
322                          << endl;
323
324     indexed_fusion_sort(nb_targets_total, indexes, sorted_indexes,
325                         responses + q * nb_targets_total);
326
327     for(int t = 0; (current_nb_fn < wanted_nb_fn_at_this_classifier) && (t < nb_targets_total - 1); t++) {
328       int u = sorted_indexes[t];
329       int v = sorted_indexes[t+1];
330       _thresholds[q] = responses[v + nb_targets_total * q];
331       if(still_detected[u]) {
332         still_detected[u] = 0;
333         current_nb_fn++;
334       }
335     }
336   }
337
338   delete[] still_detected;
339   delete[] indexes;
340   delete[] sorted_indexes;
341
342   { ////////////////////////////////////////////////////////////////////
343     // Sanity check
344
345     int nb_positives = 0;
346
347     for(int t = 0; t < nb_targets_total; t++) {
348       int positive = 1;
349       for(int q = 0; q < _nb_classifiers; q++) {
350         if(responses[t + nb_targets_total * q] < _thresholds[q]) positive = 0;
351       }
352       if(positive) nb_positives++;
353     }
354
355     scalar_t actual_tp = scalar_t(nb_positives) / scalar_t(nb_targets_total);
356
357     (*global.log_stream) << "Overall detection rate " << nb_positives << "/" << nb_targets_total
358                          << " "
359                          << "actual_tp = " << actual_tp
360                          << " "
361                          << "wanted_tp = " << wanted_tp
362                          << endl;
363
364     if(actual_tp < wanted_tp) {
365       cerr << "INCONSISTENCY" << endl;
366       abort();
367     }
368   } ////////////////////////////////////////////////////////////////////
369
370   delete[] responses;
371 }
372
373 //////////////////////////////////////////////////////////////////////
374 // Parsing
375
376 void Detector::parse_rec(RichImage *image, int level,
377                          PoseCell *cell, scalar_t current_response,
378                          PoseCellScoredSet *result) {
379
380   if(level == _nb_levels) {
381     result->add_cell_with_score(cell, current_response);
382     return;
383   }
384
385   PoseCellSet cell_set;
386   cell_set.erase_content();
387
388   if(level == 0) {
389     _hierarchy->add_root_cells(image, &cell_set);
390   } else {
391     _hierarchy->add_subcells(level, cell, &cell_set);
392   }
393
394   scalar_t *responses = new scalar_t[cell_set.nb_cells()];
395   int *keep = new int[cell_set.nb_cells()];
396
397   for(int c = 0; c < cell_set.nb_cells(); c++) {
398     responses[c] = current_response;
399     keep[c] = 1;
400   }
401
402   for(int a = 0; a < _nb_classifiers_per_level; a++) {
403     int q = level * _nb_classifiers_per_level + a;
404     SampleSet *samples = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
405     for(int c = 0; c < cell_set.nb_cells(); c++) {
406       if(keep[c]) {
407         samples->set_sample(0, _pi_feature_families[q], image, cell_set.get_cell(c), 0);
408         responses[c] += _classifiers[q]->response(samples, 0);
409         keep[c] = responses[c] >= _thresholds[q];
410       }
411     }
412     delete samples;
413   }
414
415   for(int c = 0; c < cell_set.nb_cells(); c++) {
416     if(keep[c]) {
417       parse_rec(image, level + 1, cell_set.get_cell(c), responses[c], result);
418     }
419   }
420
421   delete[] keep;
422   delete[] responses;
423 }
424
425 void Detector::parse(RichImage *image, PoseCellScoredSet *result_cell_set) {
426   result_cell_set->erase_content();
427   parse_rec(image, 0, 0, 0, result_cell_set);
428 }
429
430 //////////////////////////////////////////////////////////////////////
431 // Storage
432
433 void Detector::read(istream *is) {
434   if(_hierarchy) {
435     cerr << "Can not read over an existing Detector" << endl;
436     exit(1);
437   }
438
439   read_var(is, &_nb_levels);
440   read_var(is, &_nb_classifiers_per_level);
441
442   _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
443
444   _classifiers = new Classifier *[_nb_classifiers];
445   _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
446   _thresholds = new scalar_t[_nb_classifiers];
447
448   for(int q = 0; q < _nb_classifiers; q++) {
449     _pi_feature_families[q] = new PiFeatureFamily();
450     _pi_feature_families[q]->read(is);
451     _classifiers[q] = read_classifier(is);
452     read_var(is, &_thresholds[q]);
453   }
454
455   _hierarchy = read_hierarchy(is);
456
457   (*global.log_stream) << "Read Detector" << endl
458                        << "  _nb_levels " << _nb_levels << endl
459                        << "  _nb_classifiers_per_level " << _nb_classifiers_per_level << endl;
460 }
461
462 void Detector::write(ostream *os) {
463   write_var(os, &_nb_levels);
464   write_var(os, &_nb_classifiers_per_level);
465
466   for(int q = 0; q < _nb_classifiers; q++) {
467     _pi_feature_families[q]->write(os);
468     _classifiers[q]->write(os);
469     write_var(os, &_thresholds[q]);
470   }
471
472   _hierarchy->write(os);
473 }