automatic commit
[folded-ctf.git] / detector.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret, (C) IDIAP                                //
16 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
17 ///////////////////////////////////////////////////////////////////////////
18
19 #include "tools.h"
20 #include "detector.h"
21 #include "global.h"
22 #include "classifier_reader.h"
23 #include "pose_cell_hierarchy_reader.h"
24
25 Detector::Detector() {
26   _hierarchy = 0;
27   _nb_levels = 0;
28   _nb_classifiers_per_level = 0;
29   _thresholds = 0;
30   _nb_classifiers = 0;
31   _classifiers = 0;
32   _pi_feature_families = 0;
33 }
34
35
36 Detector::~Detector() {
37   if(_hierarchy) {
38     delete[] _thresholds;
39     for(int q = 0; q < _nb_classifiers; q++) {
40       delete _classifiers[q];
41       delete _pi_feature_families[q];
42     }
43     delete[] _classifiers;
44     delete[] _pi_feature_families;
45     delete _hierarchy;
46   }
47 }
48
49 //////////////////////////////////////////////////////////////////////
50 // Training
51
52 void Detector::train_classifier(int level,
53                                 LossMachine *loss_machine,
54                                 ParsingPool *parsing_pool,
55                                 PiFeatureFamily *pi_feature_family,
56                                 Classifier *classifier) {
57
58   // Randomize the pi-feature family
59
60   PiFeatureFamily full_pi_feature_family;
61
62   full_pi_feature_family.resize(global.nb_features_for_boosting_optimization);
63   full_pi_feature_family.randomize(level);
64
65   int nb_positives = parsing_pool->nb_positive_cells();
66
67   int nb_negatives_to_sample =
68     parsing_pool->nb_positive_cells() * global.nb_negative_samples_per_positive;
69
70   SampleSet *sample_set = new SampleSet(full_pi_feature_family.nb_features(),
71                                         nb_positives + nb_negatives_to_sample);
72
73   scalar_t *responses = new scalar_t[nb_positives + nb_negatives_to_sample];
74
75   parsing_pool->weighted_sampling(loss_machine,
76                                   &full_pi_feature_family,
77                                   sample_set,
78                                   responses);
79
80   (*global.log_stream) << "Initial train_loss "
81                        << loss_machine->loss(sample_set, responses)
82                        << endl;
83
84   classifier->train(loss_machine, sample_set, responses);
85   classifier->extract_pi_feature_family(&full_pi_feature_family, pi_feature_family);
86
87   delete[] responses;
88   delete sample_set;
89 }
90
91 void Detector::train(LabelledImagePool *train_pool,
92                      LabelledImagePool *validation_pool,
93                      LabelledImagePool *hierarchy_pool) {
94
95   if(_hierarchy) {
96     cerr << "Can not re-train a Detector" << endl;
97     exit(1);
98   }
99
100   _hierarchy = new PoseCellHierarchy(hierarchy_pool);
101
102   int nb_violations;
103
104   nb_violations = _hierarchy->nb_incompatible_poses(train_pool);
105
106   if(nb_violations > 0) {
107     cout << "The hierarchy is incompatible with the training set ("
108          << nb_violations
109          << " violations)." << endl;
110     exit(1);
111   }
112
113   nb_violations = _hierarchy->nb_incompatible_poses(validation_pool);
114
115   if(nb_violations > 0) {
116     cout << "The hierarchy is incompatible with the validation set ("
117          << nb_violations << " violations)."
118          << endl;
119     exit(1);
120   }
121
122   _nb_levels = _hierarchy->nb_levels();
123   _nb_classifiers_per_level = global.nb_classifiers_per_level;
124   _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
125   _thresholds = new scalar_t[_nb_classifiers];
126   _classifiers = new Classifier *[_nb_classifiers];
127   _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
128
129   for(int q = 0; q < _nb_classifiers; q++) {
130     _classifiers[q] = new BoostedClassifier(global.nb_weak_learners_per_classifier);
131     _pi_feature_families[q] = new PiFeatureFamily();
132   }
133
134   ParsingPool *train_parsing, *validation_parsing;
135
136   train_parsing = new ParsingPool(train_pool,
137                                   _hierarchy,
138                                   global.proportion_negative_cells_for_training);
139
140   if(global.write_validation_rocs) {
141     validation_parsing = new ParsingPool(validation_pool,
142                                          _hierarchy,
143                                          global.proportion_negative_cells_for_training);
144   } else {
145     validation_parsing = 0;
146   }
147
148   LossMachine *loss_machine = new LossMachine(global.loss_type);
149
150   cout << "Building a detector." << endl;
151
152   global.bar.init(&cout, _nb_classifiers);
153
154   for(int l = 0; l < _nb_levels; l++) {
155
156     if(l > 0) {
157       train_parsing->down_one_level(loss_machine, _hierarchy, l);
158       if(validation_parsing) {
159         validation_parsing->down_one_level(loss_machine, _hierarchy, l);
160       }
161     }
162
163     for(int c = 0; c < _nb_classifiers_per_level; c++) {
164       int q = l * _nb_classifiers_per_level + c;
165
166       // Train the classifier
167
168       train_classifier(l,
169                        loss_machine,
170                        train_parsing,
171                        _pi_feature_families[q], _classifiers[q]);
172
173       // Update the cell responses on the training set
174
175       train_parsing->update_cell_responses(_pi_feature_families[q],
176                                            _classifiers[q]);
177
178       // Save the ROC curves on the training set
179
180       char buffer[buffer_size];
181
182       sprintf(buffer, "%s/train_%05d.roc",
183               global.result_path,
184               (q + 1) * global.nb_weak_learners_per_classifier);
185       ofstream out(buffer);
186       train_parsing->write_roc(&out);
187
188       if(validation_parsing) {
189
190         // Update the cell responses on the validation set
191
192         validation_parsing->update_cell_responses(_pi_feature_families[q],
193                                                   _classifiers[q]);
194
195         // Save the ROC curves on the validation set
196
197         sprintf(buffer, "%s/validation_%05d.roc",
198                 global.result_path,
199                 (q + 1) * global.nb_weak_learners_per_classifier);
200         ofstream out(buffer);
201         validation_parsing->write_roc(&out);
202       }
203
204       _thresholds[q] = 0.0;
205
206       global.bar.refresh(&cout, q);
207     }
208   }
209
210   global.bar.finish(&cout);
211
212   delete loss_machine;
213   delete train_parsing;
214   delete validation_parsing;
215 }
216
217 void Detector::compute_thresholds(LabelledImagePool *validation_pool, scalar_t wanted_tp) {
218   LabelledImage *image;
219   int nb_targets_total = 0;
220
221   for(int i = 0; i < validation_pool->nb_images(); i++) {
222     image = validation_pool->grab_image(i);
223     nb_targets_total += image->nb_targets();
224     validation_pool->release_image(i);
225   }
226
227   scalar_t *responses = new scalar_t[_nb_classifiers * nb_targets_total];
228
229   int tt = 0;
230
231   for(int i = 0; i < validation_pool->nb_images(); i++) {
232     image = validation_pool->grab_image(i);
233     image->compute_rich_structure();
234
235     PoseCell current_cell;
236
237     for(int t = 0; t < image->nb_targets(); t++) {
238
239       scalar_t response = 0;
240
241       for(int l = 0; l < _nb_levels; l++) {
242
243         // We get the next-level cell for that target
244
245         PoseCellSet cell_set;
246
247         cell_set.erase_content();
248         if(l == 0) {
249           _hierarchy->add_root_cells(image, &cell_set);
250         } else {
251           _hierarchy->add_subcells(l, &current_cell, &cell_set);
252         }
253
254         int nb_compliant = 0;
255
256         for(int c = 0; c < cell_set.nb_cells(); c++) {
257           if(cell_set.get_cell(c)->contains(image->get_target_pose(t))) {
258             current_cell = *(cell_set.get_cell(c));
259             nb_compliant++;
260           }
261         }
262
263         if(nb_compliant != 1) {
264           cerr << "INCONSISTENCY (" << nb_compliant << " should be one)" << endl;
265           abort();
266         }
267
268         for(int c = 0; c < _nb_classifiers_per_level; c++) {
269           int q = l * _nb_classifiers_per_level + c;
270           SampleSet *sample_set = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
271           sample_set->set_sample(0, _pi_feature_families[q], image, &current_cell, 0);
272           response +=_classifiers[q]->response(sample_set, 0);
273           delete sample_set;
274           responses[tt + nb_targets_total * q] = response;
275         }
276
277       }
278
279       tt++;
280     }
281
282     validation_pool->release_image(i);
283   }
284
285   ASSERT(tt == nb_targets_total);
286
287   // Here we have in responses[] all the target responses after every
288   // classifier
289
290   int *still_detected = new int[nb_targets_total];
291   int *indexes = new int[nb_targets_total];
292   int *sorted_indexes = new int[nb_targets_total];
293
294   for(int t = 0; t < nb_targets_total; t++) {
295     still_detected[t] = 1;
296     indexes[t] = t;
297   }
298
299   int current_nb_fn = 0;
300
301   for(int q = 0; q < _nb_classifiers; q++) {
302
303     scalar_t wanted_tp_at_this_classifier
304       = exp(log(wanted_tp) * scalar_t(q + 1) / scalar_t(_nb_classifiers));
305
306     int wanted_nb_fn_at_this_classifier
307       = int(nb_targets_total * (1 - wanted_tp_at_this_classifier));
308
309     indexed_fusion_sort(nb_targets_total, indexes, sorted_indexes,
310                         responses + q * nb_targets_total);
311
312     for(int t = 0; (current_nb_fn < wanted_nb_fn_at_this_classifier) && (t < nb_targets_total - 1); t++) {
313       int u = sorted_indexes[t];
314       int v = sorted_indexes[t+1];
315       _thresholds[q] = responses[v + nb_targets_total * q];
316       if(still_detected[u]) {
317         still_detected[u] = 0;
318         current_nb_fn++;
319       }
320     }
321   }
322
323   delete[] still_detected;
324   delete[] indexes;
325   delete[] sorted_indexes;
326   delete[] responses;
327 }
328
329 //////////////////////////////////////////////////////////////////////
330 // Parsing
331
332 void Detector::parse_rec(RichImage *image, int level,
333                          PoseCell *cell, scalar_t current_response,
334                          PoseCellScoredSet *result) {
335
336   if(level == _nb_levels) {
337     result->add_cell_with_score(cell, current_response);
338     return;
339   }
340
341   PoseCellSet cell_set;
342   cell_set.erase_content();
343
344   if(level == 0) {
345     _hierarchy->add_root_cells(image, &cell_set);
346   } else {
347     _hierarchy->add_subcells(level, cell, &cell_set);
348   }
349
350   scalar_t *responses = new scalar_t[cell_set.nb_cells()];
351   int *keep = new int[cell_set.nb_cells()];
352
353   for(int c = 0; c < cell_set.nb_cells(); c++) {
354     responses[c] = current_response;
355     keep[c] = 1;
356   }
357
358   for(int a = 0; a < _nb_classifiers_per_level; a++) {
359     int q = level * _nb_classifiers_per_level + a;
360     SampleSet *samples = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
361     for(int c = 0; c < cell_set.nb_cells(); c++) {
362       if(keep[c]) {
363         samples->set_sample(0, _pi_feature_families[q], image, cell_set.get_cell(c), 0);
364         responses[c] += _classifiers[q]->response(samples, 0);
365         keep[c] = responses[c] >= _thresholds[q];
366       }
367     }
368     delete samples;
369   }
370
371   for(int c = 0; c < cell_set.nb_cells(); c++) {
372     if(keep[c]) {
373       parse_rec(image, level + 1, cell_set.get_cell(c), responses[c], result);
374     }
375   }
376
377   delete[] keep;
378   delete[] responses;
379 }
380
381 void Detector::parse(RichImage *image, PoseCellScoredSet *result_cell_set) {
382   result_cell_set->erase_content();
383   parse_rec(image, 0, 0, 0, result_cell_set);
384 }
385
386 //////////////////////////////////////////////////////////////////////
387 // Storage
388
389 void Detector::read(istream *is) {
390   if(_hierarchy) {
391     cerr << "Can not read over an existing Detector" << endl;
392     exit(1);
393   }
394
395   read_var(is, &_nb_levels);
396   read_var(is, &_nb_classifiers_per_level);
397
398   _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
399
400   _classifiers = new Classifier *[_nb_classifiers];
401   _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
402   _thresholds = new scalar_t[_nb_classifiers];
403
404   for(int q = 0; q < _nb_classifiers; q++) {
405     _pi_feature_families[q] = new PiFeatureFamily();
406     _pi_feature_families[q]->read(is);
407     _classifiers[q] = read_classifier(is);
408     read_var(is, &_thresholds[q]);
409   }
410
411   _hierarchy = read_hierarchy(is);
412 }
413
414 void Detector::write(ostream *os) {
415   write_var(os, &_nb_levels);
416   write_var(os, &_nb_classifiers_per_level);
417
418   for(int q = 0; q < _nb_classifiers; q++) {
419     _pi_feature_families[q]->write(os);
420     _classifiers[q]->write(os);
421     write_var(os, &_thresholds[q]);
422   }
423
424   _hierarchy->write(os);
425 }