94591f19b3bf2d9cbaef322a26a92aeb2f2c1cac
[svrt.git] / vision_test.cc
1 c/*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include <iostream>
26 #include <fstream>
27 #include <cmath>
28 #include <stdio.h>
29 #include <stdlib.h>
30
31 using namespace std;
32
33 #include "rgb_image.h"
34 #include "param_parser.h"
35 #include "global.h"
36
37 #include "vignette.h"
38 #include "shape.h"
39 #include "classifier.h"
40 #include "classifier_reader.h"
41 #include "naive_bayesian_classifier.h"
42 #include "boosted_classifier.h"
43 #include "error_rates.h"
44
45 #include "vision_problem_1.h"
46 #include "vision_problem_2.h"
47 #include "vision_problem_3.h"
48 #include "vision_problem_5.h"
49 #include "vision_problem_6.h"
50 #include "vision_problem_8.h"
51 #include "vision_problem_11.h"
52 #include "vision_problem_12.h"
53 #include "vision_problem_13.h"
54 #include "vision_problem_17.h"
55 #include "vision_problem_18.h"
56 #include "vision_problem_20.h"
57 #include "vision_problem_21.h"
58
59 //////////////////////////////////////////////////////////////////////
60
61 void check(bool condition, const char *message) {
62   if(!condition) {
63     cerr << message << endl;
64     exit(1);
65   }
66 }
67
68 int main(int argc, char **argv) {
69
70   char buffer[buffer_size];
71   char *new_argv[argc];
72   int new_argc = 0;
73
74   cout << "-- ARGUMENTS ---------------------------------------------------------" << endl;
75
76   for(int i = 0; i < argc; i++)
77     cout << (i > 0 ? "  " : "") << argv[i] << (i < argc - 1 ? " \\" : "")
78          << endl;
79
80   cout << "-- PARAMETERS --------------------------------------------------------" << endl;
81
82   {
83     ParamParser parser;
84     global.init_parser(&parser);
85     parser.parse_options(argc, argv, false, &new_argc, new_argv);
86     global.read_parser(&parser);
87     parser.print_all(&cout);
88   }
89
90   nice(global.niceness);
91   srand48(global.random_seed);
92
93   VignetteGenerator *generator;
94
95   switch(global.problem_number) {
96   case 1:
97     generator = new VisionProblem_1();
98     break;
99   case 2:
100     generator = new VisionProblem_2();
101     break;
102   case 3:
103     generator = new VisionProblem_3();
104     break;
105   // case 4:
106     // generator = new VisionProblem_4();
107     // break;
108   case 5:
109     generator = new VisionProblem_5();
110     break;
111   case 6:
112     generator = new VisionProblem_6();
113     break;
114   // case 7:
115     // generator = new VisionProblem_7();
116     // break;
117   case 8:
118     generator = new VisionProblem_8();
119     break;
120   // case 9:
121     // generator = new VisionProblem_9();
122     // break;
123   // case 10:
124     // generator = new VisionProblem_10();
125     // break;
126   case 11:
127     generator = new VisionProblem_11();
128     break;
129   case 12:
130     generator = new VisionProblem_12();
131     break;
132   case 13:
133     generator = new VisionProblem_13();
134     break;
135   // case 14:
136     // generator = new VisionProblem_14();
137     // break;
138   // case 15:
139     // generator = new VisionProblem_15();
140     // break;
141   // case 16:
142     // generator = new VisionProblem_16();
143     // break;
144   case 17:
145     generator = new VisionProblem_17();
146     break;
147   case 18:
148     generator = new VisionProblem_18();
149     break;
150   // case 19:
151     // generator = new VisionProblem_19();
152     // break;
153   case 20:
154     generator = new VisionProblem_20();
155     break;
156   case 21:
157     generator = new VisionProblem_21();
158     break;
159   // case 22:
160     // generator = new VisionProblem_22();
161     // break;
162   // case 23:
163     // generator = new VisionProblem_23();
164     // break;
165   default:
166     cerr << "Can not find problem "
167          << global.problem_number
168          << endl;
169     exit(1);
170   }
171
172   generator->precompute();
173
174   //////////////////////////////////////////////////////////////////////
175
176   Vignette *train_samples;
177   int *train_labels;
178
179   train_samples = new Vignette[global.nb_train_samples];
180   train_labels = new int[global.nb_train_samples];
181
182   //////////////////////////////////////////////////////////////////////
183
184   Classifier *classifier = 0;
185
186   cout << "-- COMPUTATIONS ------------------------------------------------------" << endl;
187
188   for(int c = 1; c < new_argc; c++) {
189
190     if(strcmp(new_argv[c], "randomize-train") == 0) {
191       for(int n = 0; n < global.nb_train_samples; n++) {
192         if(n > 0 && n%100 == 0) cout << "*"; cout.flush();
193         // cout << "n = " << n << endl;
194         train_labels[n] = int(drand48() * 2);
195         generator->generate(train_labels[n], &train_samples[n]);
196       }
197       cout << endl;
198     }
199
200     else if(strcmp(new_argv[c], "adaboost") == 0) {
201       delete classifier;
202       cout << "Building and training adaboost classifier." << endl;
203       classifier = new BoostedClassifier(global.nb_weak_learners);
204       classifier->train(global.nb_train_samples, train_samples, train_labels);
205     }
206
207     else if(strcmp(new_argv[c], "naive-bayesian") == 0) {
208       delete classifier;
209       cout << "Building and training naive bayesian classifier." << endl;
210       classifier = new NaiveBayesianClassifier();
211       classifier->train(global.nb_train_samples, train_samples, train_labels);
212     }
213
214     else if(strcmp(new_argv[c], "read-classifier") == 0) {
215       delete classifier;
216       sprintf(buffer, "%s", global.classifier_name);
217       cout << "Reading classifier from " << buffer << "." << endl;
218       ifstream in(buffer);
219       if(in.fail()) {
220         cerr << "Can not open " << buffer << " for reading." << endl;
221         exit(1);
222       }
223       classifier = read_classifier(&in);
224     }
225
226     else if(strcmp(new_argv[c], "write-classifier") == 0) {
227       check(classifier, "No classifier.");
228       sprintf(buffer, "%s/%s", global.result_path, global.classifier_name);
229       cout << "Writing classifier to " << buffer << "." << endl;
230       ofstream out(buffer);
231       if(out.fail()) {
232         cerr << "Can not open " << buffer << " for writing." << endl;
233         exit(1);
234       }
235       classifier->write(&out);
236     }
237
238     else if(strcmp(new_argv[c], "compute-errors-vs-nb-samples") == 0) {
239       for(int t = global.nb_train_samples; t >= 100; t /= 10) {
240         for(int n = 0; n < t; n++) {
241           train_labels[n] = int(drand48() * 2);
242           generator->generate(train_labels[n], &train_samples[n]);
243         }
244         Classifier *classifier = 0;
245         cout << "Building and training adaboost classifier with " << t << " samples." << endl;
246         classifier = new BoostedClassifier(global.nb_weak_learners);
247         classifier->train(t, train_samples, train_labels);
248         cout << "ERROR_RATES_VS_NB_SAMPLES "
249              << t
250              << " TRAIN_ERROR "
251              << error_rate(classifier, t, train_samples, train_labels)
252              << " TEST_ERROR "
253              << test_error_rate(generator, classifier, global.nb_test_samples) << endl;
254         delete classifier;
255       }
256     }
257
258     else if(strcmp(new_argv[c], "compute-train-error") == 0) {
259       check(classifier, "No classifier.");
260       cout << "TRAIN_ERROR_RATE "
261            << classifier->name()
262            << " "
263            << error_rate(classifier, global.nb_train_samples, train_samples, train_labels)
264            << endl;
265     }
266
267     else if(strcmp(new_argv[c], "compute-test-error") == 0) {
268       check(classifier, "No classifier.");
269       cout << "TEST_ERROR_RATE "
270            << classifier->name()
271            << " "
272            << test_error_rate(generator, classifier, global.nb_test_samples) << endl;
273     }
274
275     else if(strcmp(new_argv[c], "write-samples") == 0) {
276       Vignette vignette;
277       for(int k = 0; k < global.nb_train_samples; k++) {
278         for(int l = 0; l < 2; l++) {
279           generator->generate(l, &vignette);
280           sprintf(buffer, "%s/sample_%01d_%04d.png", global.result_path, l, k);
281           vignette.write_png(buffer, 1);
282           cout << "Wrote " << buffer << endl;
283         }
284       }
285     }
286
287     //////////////////////////////////////////////////////////////////////
288
289     //////////////////////////////////////////////////////////////////////
290
291     else {
292       cerr << "Unknown action " << new_argv[c] << endl;
293       exit(1);
294     }
295
296   }
297
298   cout << "-- FINISHED ----------------------------------------------------------" << endl;
299
300   delete classifier;
301   delete[] train_labels;
302   delete[] train_samples;
303   delete generator;
304 }