28c29d3389dbd496a39f1578858845a473fb9ec8
[svrt.git] / vision_test.cc
1 /*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include <iostream>
26 #include <fstream>
27 #include <cmath>
28 #include <stdio.h>
29 #include <stdlib.h>
30
31 using namespace std;
32
33 #include "rgb_image.h"
34 #include "param_parser.h"
35 #include "global.h"
36
37 #include "vignette.h"
38 #include "shape.h"
39 #include "classifier.h"
40 #include "classifier_reader.h"
41 #include "naive_bayesian_classifier.h"
42 #include "boosted_classifier.h"
43 #include "error_rates.h"
44
45 #include "vision_problem_1.h"
46 #include "vision_problem_2.h"
47 #include "vision_problem_3.h"
48 #include "vision_problem_4.h"
49 #include "vision_problem_5.h"
50 #include "vision_problem_6.h"
51 #include "vision_problem_7.h"
52 #include "vision_problem_8.h"
53 #include "vision_problem_9.h"
54 #include "vision_problem_10.h"
55 #include "vision_problem_11.h"
56 #include "vision_problem_12.h"
57 #include "vision_problem_13.h"
58 #include "vision_problem_14.h"
59 #include "vision_problem_15.h"
60 #include "vision_problem_16.h"
61 #include "vision_problem_17.h"
62 #include "vision_problem_18.h"
63 #include "vision_problem_19.h"
64 #include "vision_problem_20.h"
65 #include "vision_problem_21.h"
66 #include "vision_problem_22.h"
67 #include "vision_problem_23.h"
68
69 //////////////////////////////////////////////////////////////////////
70
71 void check(bool condition, const char *message) {
72   if(!condition) {
73     cerr << message << endl;
74     exit(1);
75   }
76 }
77
78 int main(int argc, char **argv) {
79
80   char buffer[buffer_size];
81   char *new_argv[argc];
82   int new_argc = 0;
83
84   cout << "-- ARGUMENTS ---------------------------------------------------------" << endl;
85
86   for(int i = 0; i < argc; i++)
87     cout << (i > 0 ? "  " : "") << argv[i] << (i < argc - 1 ? " \\" : "")
88          << endl;
89
90   cout << "-- PARAMETERS --------------------------------------------------------" << endl;
91
92   {
93     ParamParser parser;
94     global.init_parser(&parser);
95     parser.parse_options(argc, argv, false, &new_argc, new_argv);
96     global.read_parser(&parser);
97     parser.print_all(&cout);
98   }
99
100   nice(global.niceness);
101   srand48(global.random_seed);
102
103   VignetteGenerator *generator;
104
105   switch(global.problem_number) {
106   case 1:
107     generator = new VisionProblem_1();
108     break;
109   case 2:
110     generator = new VisionProblem_2();
111     break;
112   case 3:
113     generator = new VisionProblem_3();
114     break;
115   case 4:
116     generator = new VisionProblem_4();
117     break;
118   case 5:
119     generator = new VisionProblem_5();
120     break;
121   case 6:
122     generator = new VisionProblem_6();
123     break;
124   case 7:
125     generator = new VisionProblem_7();
126     break;
127   case 8:
128     generator = new VisionProblem_8();
129     break;
130   case 9:
131     generator = new VisionProblem_9();
132     break;
133   case 10:
134     generator = new VisionProblem_10();
135     break;
136   case 11:
137     generator = new VisionProblem_11();
138     break;
139   case 12:
140     generator = new VisionProblem_12();
141     break;
142   case 13:
143     generator = new VisionProblem_13();
144     break;
145   case 14:
146     generator = new VisionProblem_14();
147     break;
148   case 15:
149     generator = new VisionProblem_15();
150     break;
151   case 16:
152     generator = new VisionProblem_16();
153     break;
154   case 17:
155     generator = new VisionProblem_17();
156     break;
157   case 18:
158     generator = new VisionProblem_18();
159     break;
160   case 19:
161     generator = new VisionProblem_19();
162     break;
163   case 20:
164     generator = new VisionProblem_20();
165     break;
166   case 21:
167     generator = new VisionProblem_21();
168     break;
169   case 22:
170     generator = new VisionProblem_22();
171     break;
172   case 23:
173     generator = new VisionProblem_23();
174     break;
175   default:
176     cerr << "Can not find problem "
177          << global.problem_number
178          << endl;
179     exit(1);
180   }
181
182   generator->precompute();
183
184   //////////////////////////////////////////////////////////////////////
185
186   Vignette *train_samples;
187   int *train_labels;
188
189   train_samples = new Vignette[global.nb_train_samples];
190   train_labels = new int[global.nb_train_samples];
191
192   //////////////////////////////////////////////////////////////////////
193
194   Classifier *classifier = 0;
195
196   cout << "-- COMPUTATIONS ------------------------------------------------------" << endl;
197
198   for(int c = 1; c < new_argc; c++) {
199
200     if(strcmp(new_argv[c], "randomize-train") == 0) {
201       cout << "Generating the training set." << endl;
202       for(int n = 0; n < global.nb_train_samples; n++) {
203         train_labels[n] = int(drand48() * 2);
204         generator->generate(train_labels[n], &train_samples[n]);
205       }
206     }
207
208     else if(strcmp(new_argv[c], "adaboost") == 0) {
209       delete classifier;
210       cout << "Building and training adaboost classifier." << endl;
211       classifier = new BoostedClassifier(global.nb_weak_learners);
212       classifier->train(global.nb_train_samples, train_samples, train_labels);
213     }
214
215     else if(strcmp(new_argv[c], "naive-bayesian") == 0) {
216       delete classifier;
217       cout << "Building and training naive bayesian classifier." << endl;
218       classifier = new NaiveBayesianClassifier();
219       classifier->train(global.nb_train_samples, train_samples, train_labels);
220     }
221
222     else if(strcmp(new_argv[c], "read-classifier") == 0) {
223       delete classifier;
224       sprintf(buffer, "%s", global.classifier_name);
225       cout << "Reading classifier from " << buffer << "." << endl;
226       ifstream in(buffer);
227       if(in.fail()) {
228         cerr << "Can not open " << buffer << " for reading." << endl;
229         exit(1);
230       }
231       classifier = read_classifier(&in);
232     }
233
234     else if(strcmp(new_argv[c], "write-classifier") == 0) {
235       check(classifier, "No classifier.");
236       sprintf(buffer, "%s/%s", global.result_path, global.classifier_name);
237       cout << "Writing classifier to " << buffer << "." << endl;
238       ofstream out(buffer);
239       if(out.fail()) {
240         cerr << "Can not open " << buffer << " for writing." << endl;
241         exit(1);
242       }
243       classifier->write(&out);
244     }
245
246     else if(strcmp(new_argv[c], "compute-errors-vs-nb-samples") == 0) {
247       for(int t = global.nb_train_samples; t >= 100; t /= 10) {
248         for(int n = 0; n < t; n++) {
249           train_labels[n] = int(drand48() * 2);
250           generator->generate(train_labels[n], &train_samples[n]);
251         }
252         Classifier *classifier = 0;
253         cout << "Building and training adaboost classifier with " << t << " samples." << endl;
254         classifier = new BoostedClassifier(global.nb_weak_learners);
255         classifier->train(t, train_samples, train_labels);
256         cout << "ERROR_RATES_VS_NB_SAMPLES "
257              << t
258              << " TRAIN_ERROR "
259              << error_rate(classifier, t, train_samples, train_labels)
260              << " TEST_ERROR "
261              << test_error_rate(generator, classifier, global.nb_test_samples) << endl;
262         delete classifier;
263       }
264     }
265
266     else if(strcmp(new_argv[c], "compute-train-error") == 0) {
267       check(classifier, "No classifier.");
268       cout << "TRAIN_ERROR_RATE "
269            << classifier->name()
270            << " "
271            << error_rate(classifier, global.nb_train_samples, train_samples, train_labels)
272            << endl;
273     }
274
275     else if(strcmp(new_argv[c], "compute-test-error") == 0) {
276       check(classifier, "No classifier.");
277       cout << "TEST_ERROR_RATE "
278            << classifier->name()
279            << " "
280            << test_error_rate(generator, classifier, global.nb_test_samples) << endl;
281     }
282
283     else if(strcmp(new_argv[c], "write-samples") == 0) {
284       Vignette vignette;
285       for(int k = 0; k < global.nb_train_samples; k++) {
286         for(int l = 0; l < 2; l++) {
287           generator->generate(l, &vignette);
288           sprintf(buffer, "%s/sample_%01d_%04d.png", global.result_path, l, k);
289           vignette.write_png(buffer, 1);
290           cout << "Wrote " << buffer << endl;
291         }
292       }
293     }
294
295     //////////////////////////////////////////////////////////////////////
296
297     //////////////////////////////////////////////////////////////////////
298
299     else {
300       cerr << "Unknown action " << new_argv[c] << endl;
301       exit(1);
302     }
303
304   }
305
306   cout << "-- FINISHED ----------------------------------------------------------" << endl;
307
308   delete classifier;
309   delete[] train_labels;
310   delete[] train_samples;
311   delete generator;
312 }