Removed the definition of basename, which conflicted with an existing system function of the same name.
[folded-ctf.git] / detector.cc
1 /*
2  *  folded-ctf is an implementation of the folded hierarchy of
3  *  classifiers for object detection, developed by Francois Fleuret
4  *  and Donald Geman.
5  *
6  *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of folded-ctf.
10  *
11  *  folded-ctf is free software: you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  folded-ctf is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with folded-ctf.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include "tools.h"
26 #include "detector.h"
27 #include "global.h"
28 #include "classifier_reader.h"
29 #include "pose_cell_hierarchy_reader.h"
30
31 Detector::Detector() {
32   _hierarchy = 0;
33   _nb_levels = 0;
34   _nb_classifiers_per_level = 0;
35   _thresholds = 0;
36   _nb_classifiers = 0;
37   _classifiers = 0;
38   _pi_feature_families = 0;
39 }
40
41
42 Detector::~Detector() {
43   if(_hierarchy) {
44     delete[] _thresholds;
45     for(int q = 0; q < _nb_classifiers; q++) {
46       delete _classifiers[q];
47       delete _pi_feature_families[q];
48     }
49     delete[] _classifiers;
50     delete[] _pi_feature_families;
51     delete _hierarchy;
52   }
53 }
54
55 //////////////////////////////////////////////////////////////////////
56 // Training
57
58 void Detector::train_classifier(int level,
59                                 LossMachine *loss_machine,
60                                 ParsingPool *parsing_pool,
61                                 PiFeatureFamily *pi_feature_family,
62                                 Classifier *classifier) {
63
64   // Randomize the pi-feature family
65
66   PiFeatureFamily full_pi_feature_family;
67
68   full_pi_feature_family.resize(global.nb_features_for_boosting_optimization);
69   full_pi_feature_family.randomize(level);
70
71   int nb_positives = parsing_pool->nb_positive_cells();
72
73   int nb_negatives_to_sample =
74     parsing_pool->nb_positive_cells() * global.nb_negative_samples_per_positive;
75
76   SampleSet *sample_set = new SampleSet(full_pi_feature_family.nb_features(),
77                                         nb_positives + nb_negatives_to_sample);
78
79   scalar_t *responses = new scalar_t[nb_positives + nb_negatives_to_sample];
80
81   parsing_pool->weighted_sampling(loss_machine,
82                                   &full_pi_feature_family,
83                                   sample_set,
84                                   responses);
85
86   (*global.log_stream) << "Initial train_loss "
87                        << loss_machine->loss(sample_set, responses)
88                        << endl;
89
90   classifier->train(loss_machine, sample_set, responses);
91   classifier->extract_pi_feature_family(&full_pi_feature_family, pi_feature_family);
92
93   delete[] responses;
94   delete sample_set;
95 }
96
97 void Detector::train(LabelledImagePool *train_pool,
98                      LabelledImagePool *validation_pool,
99                      LabelledImagePool *hierarchy_pool) {
100
101   if(_hierarchy) {
102     cerr << "Can not re-train a Detector" << endl;
103     exit(1);
104   }
105
106   _hierarchy = new PoseCellHierarchy(hierarchy_pool);
107
108   int nb_violations;
109
110   nb_violations = _hierarchy->nb_incompatible_poses(train_pool);
111
112   if(nb_violations > 0) {
113     cout << "The hierarchy is incompatible with the training set ("
114          << nb_violations
115          << " violations)." << endl;
116     exit(1);
117   }
118
119   nb_violations = _hierarchy->nb_incompatible_poses(validation_pool);
120
121   if(nb_violations > 0) {
122     cout << "The hierarchy is incompatible with the validation set ("
123          << nb_violations << " violations)."
124          << endl;
125     exit(1);
126   }
127
128   _nb_levels = _hierarchy->nb_levels();
129   _nb_classifiers_per_level = global.nb_classifiers_per_level;
130   _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
131   _thresholds = new scalar_t[_nb_classifiers];
132   _classifiers = new Classifier *[_nb_classifiers];
133   _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
134
135   for(int q = 0; q < _nb_classifiers; q++) {
136     _classifiers[q] = new BoostedClassifier(global.nb_weak_learners_per_classifier);
137     _pi_feature_families[q] = new PiFeatureFamily();
138   }
139
140   ParsingPool *train_parsing, *validation_parsing;
141
142   train_parsing = new ParsingPool(train_pool,
143                                   _hierarchy,
144                                   global.proportion_negative_cells_for_training);
145
146   if(global.write_validation_rocs) {
147     validation_parsing = new ParsingPool(validation_pool,
148                                          _hierarchy,
149                                          global.proportion_negative_cells_for_training);
150   } else {
151     validation_parsing = 0;
152   }
153
154   LossMachine *loss_machine = new LossMachine(global.loss_type);
155
156   cout << "Building a detector." << endl;
157
158   global.bar.init(&cout, _nb_classifiers);
159
160   for(int l = 0; l < _nb_levels; l++) {
161
162     if(l > 0) {
163       train_parsing->down_one_level(loss_machine, _hierarchy, l);
164       if(validation_parsing) {
165         validation_parsing->down_one_level(loss_machine, _hierarchy, l);
166       }
167     }
168
169     for(int c = 0; c < _nb_classifiers_per_level; c++) {
170       int q = l * _nb_classifiers_per_level + c;
171
172       // Train the classifier
173
174       train_classifier(l,
175                        loss_machine,
176                        train_parsing,
177                        _pi_feature_families[q], _classifiers[q]);
178
179       // Update the cell responses on the training set
180
181       train_parsing->update_cell_responses(_pi_feature_families[q],
182                                            _classifiers[q]);
183
184       // Save the ROC curves on the training set
185
186       char buffer[buffer_size];
187
188       sprintf(buffer, "%s/train_%05d.roc",
189               global.result_path,
190               (q + 1) * global.nb_weak_learners_per_classifier);
191       ofstream out(buffer);
192       train_parsing->write_roc(&out);
193
194       if(validation_parsing) {
195
196         // Update the cell responses on the validation set
197
198         validation_parsing->update_cell_responses(_pi_feature_families[q],
199                                                   _classifiers[q]);
200
201         // Save the ROC curves on the validation set
202
203         sprintf(buffer, "%s/validation_%05d.roc",
204                 global.result_path,
205                 (q + 1) * global.nb_weak_learners_per_classifier);
206         ofstream out(buffer);
207         validation_parsing->write_roc(&out);
208       }
209
210       _thresholds[q] = 0.0;
211
212       global.bar.refresh(&cout, q);
213     }
214   }
215
216   global.bar.finish(&cout);
217
218   delete loss_machine;
219   delete train_parsing;
220   delete validation_parsing;
221 }
222
223 void Detector::compute_thresholds(LabelledImagePool *validation_pool, scalar_t wanted_tp) {
224   LabelledImage *image;
225   int nb_targets_total = 0;
226
227   for(int i = 0; i < validation_pool->nb_images(); i++) {
228     image = validation_pool->grab_image(i);
229     nb_targets_total += image->nb_targets();
230     validation_pool->release_image(i);
231   }
232
233   scalar_t *responses = new scalar_t[_nb_classifiers * nb_targets_total];
234
235   int tt = 0;
236
237   for(int i = 0; i < validation_pool->nb_images(); i++) {
238     image = validation_pool->grab_image(i);
239     image->compute_rich_structure();
240
241     PoseCell current_cell;
242
243     for(int t = 0; t < image->nb_targets(); t++) {
244
245       scalar_t response = 0;
246
247       for(int l = 0; l < _nb_levels; l++) {
248
249         // We get the next-level cell for that target
250
251         PoseCellSet cell_set;
252
253         cell_set.erase_content();
254         if(l == 0) {
255           _hierarchy->add_root_cells(image, &cell_set);
256         } else {
257           _hierarchy->add_subcells(l, &current_cell, &cell_set);
258         }
259
260         int nb_compliant = 0;
261
262         for(int c = 0; c < cell_set.nb_cells(); c++) {
263           if(cell_set.get_cell(c)->contains(image->get_target_pose(t))) {
264             current_cell = *(cell_set.get_cell(c));
265             nb_compliant++;
266           }
267         }
268
269         if(nb_compliant != 1) {
270           cerr << "INCONSISTENCY (" << nb_compliant << " should be one)" << endl;
271           abort();
272         }
273
274         for(int c = 0; c < _nb_classifiers_per_level; c++) {
275           int q = l * _nb_classifiers_per_level + c;
276           SampleSet *sample_set = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
277           sample_set->set_sample(0, _pi_feature_families[q], image, &current_cell, 0);
278           response +=_classifiers[q]->response(sample_set, 0);
279           delete sample_set;
280           responses[tt + nb_targets_total * q] = response;
281         }
282
283       }
284
285       tt++;
286     }
287
288     validation_pool->release_image(i);
289   }
290
291   ASSERT(tt == nb_targets_total);
292
293   // Here we have in responses[] all the target responses after every
294   // classifier
295
296   int *still_detected = new int[nb_targets_total];
297   int *indexes = new int[nb_targets_total];
298   int *sorted_indexes = new int[nb_targets_total];
299
300   for(int t = 0; t < nb_targets_total; t++) {
301     still_detected[t] = 1;
302     indexes[t] = t;
303   }
304
305   int current_nb_fn = 0;
306
307   for(int q = 0; q < _nb_classifiers; q++) {
308
309     scalar_t wanted_tp_at_this_classifier
310       = exp(log(wanted_tp) * scalar_t(q + 1) / scalar_t(_nb_classifiers));
311
312     int wanted_nb_fn_at_this_classifier
313       = int(nb_targets_total * (1 - wanted_tp_at_this_classifier));
314
315     indexed_fusion_sort(nb_targets_total, indexes, sorted_indexes,
316                         responses + q * nb_targets_total);
317
318     for(int t = 0; (current_nb_fn < wanted_nb_fn_at_this_classifier) && (t < nb_targets_total - 1); t++) {
319       int u = sorted_indexes[t];
320       int v = sorted_indexes[t+1];
321       _thresholds[q] = responses[v + nb_targets_total * q];
322       if(still_detected[u]) {
323         still_detected[u] = 0;
324         current_nb_fn++;
325       }
326     }
327   }
328
329   delete[] still_detected;
330   delete[] indexes;
331   delete[] sorted_indexes;
332   delete[] responses;
333 }
334
335 //////////////////////////////////////////////////////////////////////
336 // Parsing
337
338 void Detector::parse_rec(RichImage *image, int level,
339                          PoseCell *cell, scalar_t current_response,
340                          PoseCellScoredSet *result) {
341
342   if(level == _nb_levels) {
343     result->add_cell_with_score(cell, current_response);
344     return;
345   }
346
347   PoseCellSet cell_set;
348   cell_set.erase_content();
349
350   if(level == 0) {
351     _hierarchy->add_root_cells(image, &cell_set);
352   } else {
353     _hierarchy->add_subcells(level, cell, &cell_set);
354   }
355
356   scalar_t *responses = new scalar_t[cell_set.nb_cells()];
357   int *keep = new int[cell_set.nb_cells()];
358
359   for(int c = 0; c < cell_set.nb_cells(); c++) {
360     responses[c] = current_response;
361     keep[c] = 1;
362   }
363
364   for(int a = 0; a < _nb_classifiers_per_level; a++) {
365     int q = level * _nb_classifiers_per_level + a;
366     SampleSet *samples = new SampleSet(_pi_feature_families[q]->nb_features(), 1);
367     for(int c = 0; c < cell_set.nb_cells(); c++) {
368       if(keep[c]) {
369         samples->set_sample(0, _pi_feature_families[q], image, cell_set.get_cell(c), 0);
370         responses[c] += _classifiers[q]->response(samples, 0);
371         keep[c] = responses[c] >= _thresholds[q];
372       }
373     }
374     delete samples;
375   }
376
377   for(int c = 0; c < cell_set.nb_cells(); c++) {
378     if(keep[c]) {
379       parse_rec(image, level + 1, cell_set.get_cell(c), responses[c], result);
380     }
381   }
382
383   delete[] keep;
384   delete[] responses;
385 }
386
387 void Detector::parse(RichImage *image, PoseCellScoredSet *result_cell_set) {
388   result_cell_set->erase_content();
389   parse_rec(image, 0, 0, 0, result_cell_set);
390 }
391
392 //////////////////////////////////////////////////////////////////////
393 // Storage
394
395 void Detector::read(istream *is) {
396   if(_hierarchy) {
397     cerr << "Can not read over an existing Detector" << endl;
398     exit(1);
399   }
400
401   read_var(is, &_nb_levels);
402   read_var(is, &_nb_classifiers_per_level);
403
404   _nb_classifiers = _nb_levels * _nb_classifiers_per_level;
405
406   _classifiers = new Classifier *[_nb_classifiers];
407   _pi_feature_families = new PiFeatureFamily *[_nb_classifiers];
408   _thresholds = new scalar_t[_nb_classifiers];
409
410   for(int q = 0; q < _nb_classifiers; q++) {
411     _pi_feature_families[q] = new PiFeatureFamily();
412     _pi_feature_families[q]->read(is);
413     _classifiers[q] = read_classifier(is);
414     read_var(is, &_thresholds[q]);
415   }
416
417   _hierarchy = read_hierarchy(is);
418 }
419
420 void Detector::write(ostream *os) {
421   write_var(os, &_nb_levels);
422   write_var(os, &_nb_classifiers_per_level);
423
424   for(int q = 0; q < _nb_classifiers; q++) {
425     _pi_feature_families[q]->write(os);
426     _classifiers[q]->write(os);
427     write_var(os, &_thresholds[q]);
428   }
429
430   _hierarchy->write(os);
431 }