automatic commit
[folded-ctf.git] / folding.cc
1
2 ///////////////////////////////////////////////////////////////////////////
3 // This program is free software: you can redistribute it and/or modify  //
4 // it under the terms of the version 3 of the GNU General Public License //
5 // as published by the Free Software Foundation.                         //
6 //                                                                       //
7 // This program is distributed in the hope that it will be useful, but   //
8 // WITHOUT ANY WARRANTY; without even the implied warranty of            //
9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU      //
10 // General Public License for more details.                              //
11 //                                                                       //
12 // You should have received a copy of the GNU General Public License     //
13 // along with this program. If not, see <http://www.gnu.org/licenses/>.  //
14 //                                                                       //
15 // Written by Francois Fleuret, (C) IDIAP                                //
16 // Contact <francois.fleuret@idiap.ch> for comments & bug reports        //
17 ///////////////////////////////////////////////////////////////////////////
18
19 #include <iostream>
20 #include <fstream>
21 #include <cmath>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <string.h>
25
26 using namespace std;
27
28 #include "misc.h"
29 #include "param_parser.h"
30 #include "global.h"
31 #include "labelled_image_pool_file.h"
32 #include "labelled_image_pool_subset.h"
33 #include "tools.h"
34 #include "detector.h"
35 #include "pose_cell_hierarchy.h"
36 #include "error_rates.h"
37 #include "materials.h"
38
39 //////////////////////////////////////////////////////////////////////
40
// Guard used by the action dispatcher in main(): when `condition` is
// false, print `message` on the standard error stream and terminate the
// process with exit status 1. When `condition` holds, do nothing.
void check(bool condition, const char *message) {
  if(condition)
    return;

  cerr << message << endl;
  exit(1);
}
47
48 //////////////////////////////////////////////////////////////////////
49
50 int main(int argc, char **argv) {
51   char *new_argv[argc];
52   int new_argc = 0;
53
54 #ifdef DEBUG
55   cout << endl;
56   cout << "**********************************************************************" << endl;
57   cout << "**                     COMPILED IN DEBUG MODE                       **" << endl;
58   cout << "**********************************************************************" << endl;
59   cout << endl;
60 #endif
61
62   cout << "-- ARGUMENTS ---------------------------------------------------------" << endl;
63   for(int i = 0; i < argc; i++)
64     cout << (i > 0 ? "  " : "") << argv[i] << (i < argc - 1 ? " \\" : "")
65          << endl;
66
67   {
68     ParamParser parser;
69     global.init_parser(&parser);
70     parser.parse_options(argc, argv, false, &new_argc, new_argv);
71     global.read_parser(&parser);
72     (*global.log_stream)
73       << "-- PARAMETERS --------------------------------------------------------"
74       << endl;
75     parser.print_all(global.log_stream);
76   }
77
78   nice(global.niceness);
79
80   (*global.log_stream) << "INFO RANDOM_SEED " << global.random_seed << endl;
81   srand48(global.random_seed);
82
83   LabelledImagePool *main_pool = 0;
84   LabelledImagePool *train_pool = 0, *validation_pool = 0, *hierarchy_pool = 0;
85   LabelledImagePool *test_pool = 0;
86   Detector *detector = 0;
87
88   {
89     char buffer[buffer_size];
90     gethostname(buffer, buffer_size);
91     (*global.log_stream) << "INFO HOSTNAME " << buffer << endl;
92   }
93
94   for(int c = 1; c < new_argc; c++) {
95
96     if(strcmp(new_argv[c], "open-pool") == 0) {
97       cout
98         << "-- OPENING POOL ------------------------------------------------------"
99         << endl;
100
101       check(!main_pool, "Pool already opened.");
102       check(global.pool_name[0], "No pool file.");
103
104       main_pool = new LabelledImagePoolFile(global.pool_name);
105
106       bool for_test[main_pool->nb_images()];
107       bool for_train[main_pool->nb_images()];
108       bool for_validation[main_pool->nb_images()];
109       bool for_hierarchy[main_pool->nb_images()];
110
111       for(int n = 0; n < main_pool->nb_images(); n++) {
112         for_test[n] = false;
113         for_train[n] = false;
114         for_validation[n] = false;
115         scalar_t r = drand48();
116         if(r < global.proportion_for_train)
117           for_train[n] = true;
118         else if(r < global.proportion_for_train + global.proportion_for_validation)
119           for_validation[n] = true;
120         else if(global.proportion_for_test < 0 ||
121                 r < global.proportion_for_train +
122                 global.proportion_for_validation +
123                 global.proportion_for_test)
124           for_test[n] = true;
125         for_hierarchy[n] = for_train[n] || for_validation[n];
126       }
127
128       train_pool = new LabelledImagePoolSubset(main_pool, for_train);
129       validation_pool = new LabelledImagePoolSubset(main_pool, for_validation);
130       hierarchy_pool = new LabelledImagePoolSubset(main_pool, for_hierarchy);
131
132       if(global.test_pool_name[0]) {
133         test_pool = new LabelledImagePoolFile(global.test_pool_name);
134       } else {
135         test_pool = new LabelledImagePoolSubset(main_pool, for_test);
136       }
137
138       cout << "Using "
139            << train_pool->nb_images() << " images for train, "
140            << validation_pool->nb_images() << " images for validation, "
141            << hierarchy_pool->nb_images() << " images for the hierarchy and "
142            << test_pool->nb_images() << " images for test."
143            << endl;
144
145     }
146
147     else if(strcmp(new_argv[c], "write-target-poses") == 0) {
148       check(main_pool, "No pool available.");
149       LabelledImage *image;
150       for(int p = 0; p < main_pool->nb_images(); p++) {
151         image = main_pool->grab_image(p);
152         for(int t = 0; t < image->nb_targets(); t++) {
153           cout << "IMAGE " << p << " TARGET " << t << endl;
154           image->get_target_pose(t)->print(&cout);
155         }
156         main_pool->release_image(p);
157       }
158     }
159
160     //////////////////////////////////////////////////////////////////////
161
162     else if(strcmp(new_argv[c], "train-detector") == 0) {
163       cout << "-- TRAIN DETECTOR ----------------------------------------------------" << endl;
164       check(train_pool, "No train pool available.");
165       check(validation_pool, "No validation pool available.");
166       check(hierarchy_pool, "No hierarchy pool available.");
167       check(!detector, "Existing detector, can not train another one.");
168       detector = new Detector();
169       detector->train(train_pool, validation_pool, hierarchy_pool);
170     }
171
172     else if(strcmp(new_argv[c], "compute-thresholds") == 0) {
173       cout << "-- COMPUTE THRESHOLDS ------------------------------------------------" << endl;
174       check(validation_pool, "No validation pool available.");
175       check(detector, "No detector.");
176       detector->compute_thresholds(validation_pool, global.wanted_true_positive_rate);
177     }
178
179     else if(strcmp(new_argv[c], "check-hierarchy") == 0) {
180       cout << "-- CHECK HIERARCHY ---------------------------------------------------" << endl;
181       PoseCellHierarchy *h = new PoseCellHierarchy(hierarchy_pool);
182       cout << "Train incompatible poses " << h->nb_incompatible_poses(train_pool) << endl;
183       cout << "Validation incompatible poses " << h->nb_incompatible_poses(validation_pool) << endl;
184       delete h;
185     }
186
187     //////////////////////////////////////////////////////////////////////
188
189     else if(strcmp(new_argv[c], "validate-detector") == 0) {
190       cout << "-- VALIDATE DETECTOR -------------------------------------------------" << endl;
191
192       check(validation_pool, "No validation pool available.");
193       check(detector, "No detector.");
194
195       print_decimated_error_rate(global.nb_levels - 1, validation_pool, detector);
196     }
197
198     //////////////////////////////////////////////////////////////////////
199
200     else if(strcmp(new_argv[c], "test-detector") == 0) {
201       cout << "-- TEST DETECTOR -----------------------------------------------------" << endl;
202
203       check(test_pool, "No test pool available.");
204       check(detector, "No detector.");
205
206       if(test_pool->nb_images() > 0) {
207         print_decimated_error_rate(global.nb_levels - 1, test_pool, detector);
208       } else {
209         cout << "No test image." << endl;
210       }
211     }
212
213     else if(strcmp(new_argv[c], "parse-images") == 0) {
214       cout << "-- PARSING IMAGES -----------------------------------------------------" << endl;
215       check(detector, "No detector.");
216       while(!cin.eof()) {
217         char image_name[buffer_size];
218         cin.getline(image_name, buffer_size);
219         if(strlen(image_name) > 0) {
220           parse_scene(detector, image_name);
221         }
222       }
223     }
224
225     //////////////////////////////////////////////////////////////////////
226
227     else if(strcmp(new_argv[c], "sequence-test-detector") == 0) {
228       cout << "-- SEQUENCE TEST DETECTOR --------------------------------------------" << endl;
229
230       check(test_pool, "No test pool available.");
231       check(detector, "No detector.");
232
233       if(test_pool->nb_images() > 0) {
234
235         for(int n = 0; n < global.nb_wanted_true_positive_rates; n++) {
236           scalar_t r = global.wanted_true_positive_rate *
237             scalar_t(n + 1) / scalar_t(global.nb_wanted_true_positive_rates);
238           cout << "Testint at tp " << r
239                << " (" << n + 1 << "/" << global.nb_wanted_true_positive_rates << ")"
240                << endl;
241           (*global.log_stream) << "INFO THRESHOLD_FOR_TP " << r << endl;
242           detector->compute_thresholds(validation_pool, r);
243           print_decimated_error_rate(global.nb_levels - 1, test_pool, detector);
244         }
245       } else {
246         cout << "No test image." << endl;
247       }
248     }
249
250     //////////////////////////////////////////////////////////////////////
251
252     else if(strcmp(new_argv[c], "write-detector") == 0) {
253       cout << "-- WRITE DETECTOR ----------------------------------------------------" << endl;
254       ofstream out(global.detector_name);
255       if(out.fail()) {
256         cerr << "Can not write to " << global.detector_name << endl;
257         exit(1);
258       }
259       check(detector, "No detector available.");
260       detector->write(&out);
261     }
262
263     //////////////////////////////////////////////////////////////////////
264
265     else if(strcmp(new_argv[c], "read-detector") == 0) {
266       cout << "-- READ DETECTOR -----------------------------------------------------" << endl;
267
268       check(!detector, "Existing detector, can not load another one.");
269
270       ifstream in(global.detector_name);
271       if(in.fail()) {
272         cerr << "Can not read from " << global.detector_name << endl;
273         exit(1);
274       }
275
276       detector = new Detector();
277       detector->read(&in);
278     }
279
280     //////////////////////////////////////////////////////////////////////
281
282     else if(strcmp(new_argv[c], "write-pool-images") == 0) {
283       cout << "-- WRITING POOL IMAGES -----------------------------------------------" << endl;
284       check(global.nb_images > 0, "You must set nb_images to a positive value.");
285       check(train_pool, "No train pool available.");
286       write_pool_images_with_poses_and_referentials(train_pool, detector);
287     }
288
289     else if(strcmp(new_argv[c], "produce-materials") == 0) {
290       cout << "-- PRODUCING MATERIALS -----------------------------------------------" << endl;
291
292       check(hierarchy_pool, "No hierarchy pool available.");
293       check(test_pool, "No test pool available.");
294
295       PoseCellHierarchy *hierarchy;
296
297       cout << "Creating hierarchy" << endl;
298
299       hierarchy = new PoseCellHierarchy(hierarchy_pool);
300
301       LabelledImage *image;
302       for(int p = 0; p < test_pool->nb_images(); p++) {
303         image = test_pool->grab_image(p);
304         if(image->width() == 640 && image->height() == 480) {
305           PoseCellSet pcs;
306           hierarchy->add_root_cells(image, &pcs);
307           cout << "WE HAVE " << pcs.nb_cells() << " CELLS" << endl;
308           exit(0);
309           test_pool->release_image(p);
310         }
311       }
312
313       delete hierarchy;
314
315     }
316
317     //////////////////////////////////////////////////////////////////////
318
319     else {
320       cerr << "Unknown action " << new_argv[c] << endl;
321       exit(1);
322     }
323
324     //////////////////////////////////////////////////////////////////////
325
326   }
327
328   delete detector;
329
330   delete train_pool;
331   delete validation_pool;
332   delete hierarchy_pool;
333   delete test_pool;
334
335   delete main_pool;
336
337   cout << "-- FINISHED ----------------------------------------------------------" << endl;
338
339 }