automatic commit
[folded-ctf.git] / folding.cc
1 /*
2  *  folded-ctf is an implementation of the folded hierarchy of
3  *  classifiers for object detection, developed by Francois Fleuret
4  *  and Donald Geman.
5  *
6  *  Copyright (c) 2008 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of folded-ctf.
10  *
11  *  folded-ctf is free software: you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License as published
13  *  by the Free Software Foundation, either version 3 of the License,
14  *  or (at your option) any later version.
15  *
16  *  folded-ctf is distributed in the hope that it will be useful, but
17  *  WITHOUT ANY WARRANTY; without even the implied warranty of
18  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  *  General Public License for more details.
20  *
21  *  You should have received a copy of the GNU General Public License
22  *  along with folded-ctf.  If not, see <http://www.gnu.org/licenses/>.
23  *
24  */
25
26 #include <iostream>
27 #include <fstream>
28 #include <cmath>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32
33 using namespace std;
34
35 #include "misc.h"
36 #include "param_parser.h"
37 #include "global.h"
38 #include "labelled_image_pool_file.h"
39 #include "labelled_image_pool_subset.h"
40 #include "tools.h"
41 #include "detector.h"
42 #include "pose_cell_hierarchy.h"
43 #include "error_rates.h"
44 #include "materials.h"
45
46 //////////////////////////////////////////////////////////////////////
47
48 void check(bool condition, const char *message) {
49   if(!condition) {
50     cerr << message << endl;
51     exit(1);
52   }
53 }
54
55 //////////////////////////////////////////////////////////////////////
56
57 int main(int argc, char **argv) {
58
59 #ifdef DEBUG
60   cout << endl;
61   cout << "**********************************************************************" << endl;
62   cout << "**                     COMPILED IN DEBUG MODE                       **" << endl;
63   cout << "**********************************************************************" << endl;
64   cout << endl;
65 #endif
66
67   char *new_argv[argc];
68   int new_argc = 0;
69
70   cout << "-- ARGUMENTS ---------------------------------------------------------" << endl;
71
72   for(int i = 0; i < argc; i++)
73     cout << (i > 0 ? "  " : "") << argv[i] << (i < argc - 1 ? " \\" : "")
74          << endl;
75
76   {
77     ParamParser parser;
78     global.init_parser(&parser);
79     parser.parse_options(argc, argv, false, &new_argc, new_argv);
80     global.read_parser(&parser);
81     (*global.log_stream)
82       << "-- PARAMETERS --------------------------------------------------------"
83       << endl;
84     parser.print_all(global.log_stream);
85   }
86
87   nice(global.niceness);
88
89   (*global.log_stream) << "INFO RANDOM_SEED " << global.random_seed << endl;
90   srand48(global.random_seed);
91
92   LabelledImagePool *main_pool = 0;
93   LabelledImagePool *train_pool = 0, *validation_pool = 0, *hierarchy_pool = 0;
94   LabelledImagePool *test_pool = 0;
95   Detector *detector = 0;
96
97   {
98     char buffer[buffer_size];
99     gethostname(buffer, buffer_size);
100     (*global.log_stream) << "INFO HOSTNAME " << buffer << endl;
101   }
102
103   for(int c = 1; c < new_argc; c++) {
104
105     if(strcmp(new_argv[c], "open-pool") == 0) {
106       cout
107         << "-- OPENING POOL ------------------------------------------------------"
108         << endl;
109
110       check(!main_pool, "Pool already opened.");
111       check(global.pool_name[0], "No pool file.");
112
113       main_pool = new LabelledImagePoolFile(global.pool_name);
114
115       bool for_test[main_pool->nb_images()];
116       bool for_train[main_pool->nb_images()];
117       bool for_validation[main_pool->nb_images()];
118       bool for_hierarchy[main_pool->nb_images()];
119
120       for(int n = 0; n < main_pool->nb_images(); n++) {
121         for_test[n] = false;
122         for_train[n] = false;
123         for_validation[n] = false;
124         scalar_t r = drand48();
125         if(r < global.proportion_for_train)
126           for_train[n] = true;
127         else if(r < global.proportion_for_train + global.proportion_for_validation)
128           for_validation[n] = true;
129         else if(global.proportion_for_test < 0 ||
130                 r < global.proportion_for_train +
131                 global.proportion_for_validation +
132                 global.proportion_for_test)
133           for_test[n] = true;
134         for_hierarchy[n] = for_train[n] || for_validation[n];
135       }
136
137       train_pool = new LabelledImagePoolSubset(main_pool, for_train);
138       validation_pool = new LabelledImagePoolSubset(main_pool, for_validation);
139       hierarchy_pool = new LabelledImagePoolSubset(main_pool, for_hierarchy);
140
141       if(global.test_pool_name[0]) {
142         test_pool = new LabelledImagePoolFile(global.test_pool_name);
143       } else {
144         test_pool = new LabelledImagePoolSubset(main_pool, for_test);
145       }
146
147       cout << "Using "
148            << train_pool->nb_images() << " images for train, "
149            << validation_pool->nb_images() << " images for validation, "
150            << hierarchy_pool->nb_images() << " images for the hierarchy and "
151            << test_pool->nb_images() << " images for test."
152            << endl;
153
154     }
155
156     //////////////////////////////////////////////////////////////////////
157
158     else if(strcmp(new_argv[c], "train-detector") == 0) {
159       cout << "-- TRAIN DETECTOR ----------------------------------------------------" << endl;
160       check(train_pool, "No train pool available.");
161       check(validation_pool, "No validation pool available.");
162       check(hierarchy_pool, "No hierarchy pool available.");
163       check(!detector, "Existing detector, can not train another one.");
164       detector = new Detector();
165       detector->train(train_pool, validation_pool, hierarchy_pool);
166     }
167
168     else if(strcmp(new_argv[c], "compute-thresholds") == 0) {
169       cout << "-- COMPUTE THRESHOLDS ------------------------------------------------" << endl;
170       check(validation_pool, "No validation pool available.");
171       check(detector, "No detector.");
172       detector->compute_thresholds(validation_pool, global.wanted_true_positive_rate);
173     }
174
175     //////////////////////////////////////////////////////////////////////
176
177     else if(strcmp(new_argv[c], "test-detector") == 0) {
178       cout << "-- TEST DETECTOR -----------------------------------------------------" << endl;
179
180       check(test_pool, "No test pool available.");
181       check(detector, "No detector.");
182
183       if(test_pool->nb_images() > 0) {
184         print_decimated_error_rate(global.nb_levels - 1, test_pool, detector);
185       } else {
186         cout << "No test image." << endl;
187       }
188     }
189
190     else if(strcmp(new_argv[c], "sequence-test-detector") == 0) {
191       cout << "-- SEQUENCE TEST DETECTOR --------------------------------------------" << endl;
192
193       check(test_pool, "No test pool available.");
194       check(detector, "No detector.");
195
196       if(test_pool->nb_images() > 0) {
197
198         for(int n = 0; n < global.nb_wanted_true_positive_rates; n++) {
199           scalar_t r = global.wanted_true_positive_rate *
200             scalar_t(n + 1) / scalar_t(global.nb_wanted_true_positive_rates);
201           cout << "Testint at tp " << r
202                << " (" << n + 1 << "/" << global.nb_wanted_true_positive_rates << ")"
203                << endl;
204           (*global.log_stream) << "INFO THRESHOLD_FOR_TP " << r << endl;
205           detector->compute_thresholds(validation_pool, r);
206           print_decimated_error_rate(global.nb_levels - 1, test_pool, detector);
207         }
208       } else {
209         cout << "No test image." << endl;
210       }
211     }
212
213     //////////////////////////////////////////////////////////////////////
214
215     else if(strcmp(new_argv[c], "write-detector") == 0) {
216       cout << "-- WRITE DETECTOR ----------------------------------------------------" << endl;
217       ofstream out(global.detector_name);
218       if(out.fail()) {
219         cerr << "Can not write to " << global.detector_name << endl;
220         exit(1);
221       }
222       check(detector, "No detector available.");
223       detector->write(&out);
224     }
225
226     else if(strcmp(new_argv[c], "read-detector") == 0) {
227       cout << "-- READ DETECTOR -----------------------------------------------------" << endl;
228
229       check(!detector, "Existing detector, can not load another one.");
230
231       ifstream in(global.detector_name);
232       if(in.fail()) {
233         cerr << "Can not read from " << global.detector_name << endl;
234         exit(1);
235       }
236
237       detector = new Detector();
238       detector->read(&in);
239     }
240
241     //////////////////////////////////////////////////////////////////////
242
243     else if(strcmp(new_argv[c], "write-pool-images") == 0) {
244       cout << "-- WRITING POOL IMAGES -----------------------------------------------" << endl;
245       check(global.nb_images > 0, "You must set nb_images to a positive value.");
246       check(train_pool, "No train pool available.");
247       write_pool_images_with_poses_and_referentials(train_pool, detector);
248     }
249
250     //////////////////////////////////////////////////////////////////////
251
252     else {
253       cerr << "Unknown action " << new_argv[c] << endl;
254       exit(1);
255     }
256
257     //////////////////////////////////////////////////////////////////////
258
259   }
260
261   delete detector;
262
263   delete train_pool;
264   delete validation_pool;
265   delete hierarchy_pool;
266   delete test_pool;
267
268   delete main_pool;
269
270   cout << "-- FINISHED ----------------------------------------------------------" << endl;
271
272 }