2 * svrt is the ``Synthetic Visual Reasoning Test'', an image
3 * generator for evaluating classification performance of machine
4 * learning systems, humans and primates.
6 * Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7 * Written by Francois Fleuret <francois.fleuret@idiap.ch>
9 * This file is part of svrt.
11 * svrt is free software: you can redistribute it and/or modify it
12 * under the terms of the GNU General Public License version 3 as
13 * published by the Free Software Foundation.
15 * svrt is distributed in the hope that it will be useful, but
16 * WITHOUT ANY WARRANTY; without even the implied warranty of
17 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 * General Public License for more details.
20 * You should have received a copy of the GNU General Public License
21 * along with svrt. If not, see <http://www.gnu.org/licenses/>.
34 #include "rgb_image.h"
35 #include "param_parser.h"
40 #include "classifier.h"
41 #include "classifier_reader.h"
42 #include "naive_bayesian_classifier.h"
43 #include "boosted_classifier.h"
44 #include "error_rates.h"
46 #include "vision_problem_1.h"
47 #include "vision_problem_2.h"
48 #include "vision_problem_3.h"
49 #include "vision_problem_4.h"
50 #include "vision_problem_5.h"
51 #include "vision_problem_6.h"
52 #include "vision_problem_7.h"
53 #include "vision_problem_8.h"
54 #include "vision_problem_9.h"
55 #include "vision_problem_10.h"
56 #include "vision_problem_11.h"
57 #include "vision_problem_12.h"
58 #include "vision_problem_13.h"
59 #include "vision_problem_14.h"
60 #include "vision_problem_15.h"
61 #include "vision_problem_16.h"
62 #include "vision_problem_17.h"
63 #include "vision_problem_18.h"
64 #include "vision_problem_19.h"
65 #include "vision_problem_20.h"
66 #include "vision_problem_21.h"
67 #include "vision_problem_22.h"
68 #include "vision_problem_23.h"
70 //////////////////////////////////////////////////////////////////////
72 void check(bool condition, const char *message) {
74 cerr << message << endl;
79 int main(int argc, char **argv) {
81 char buffer[buffer_size];
85 cout << "-- ARGUMENTS ---------------------------------------------------------" << endl;
87 for(int i = 0; i < argc; i++)
88 cout << (i > 0 ? " " : "") << argv[i] << (i < argc - 1 ? " \\" : "")
91 cout << "-- PARAMETERS --------------------------------------------------------" << endl;
95 global.init_parser(&parser);
96 parser.parse_options(argc, argv, false, &new_argc, new_argv);
97 global.read_parser(&parser);
98 parser.print_all(&cout);
101 nice(global.niceness);
102 srand48(global.random_seed);
104 VignetteGenerator *generator;
106 switch(global.problem_number) {
108 generator = new VisionProblem_1();
111 generator = new VisionProblem_2();
114 generator = new VisionProblem_3();
117 generator = new VisionProblem_4();
120 generator = new VisionProblem_5();
123 generator = new VisionProblem_6();
126 generator = new VisionProblem_7();
129 generator = new VisionProblem_8();
132 generator = new VisionProblem_9();
135 generator = new VisionProblem_10();
138 generator = new VisionProblem_11();
141 generator = new VisionProblem_12();
144 generator = new VisionProblem_13();
147 generator = new VisionProblem_14();
150 generator = new VisionProblem_15();
153 generator = new VisionProblem_16();
156 generator = new VisionProblem_17();
159 generator = new VisionProblem_18();
162 generator = new VisionProblem_19();
165 generator = new VisionProblem_20();
168 generator = new VisionProblem_21();
171 generator = new VisionProblem_22();
174 generator = new VisionProblem_23();
177 cerr << "Can not find problem "
178 << global.problem_number
183 generator->precompute();
185 //////////////////////////////////////////////////////////////////////
187 Vignette *train_samples;
190 train_samples = new Vignette[global.nb_train_samples];
191 train_labels = new int[global.nb_train_samples];
193 //////////////////////////////////////////////////////////////////////
195 Classifier *classifier = 0;
197 cout << "-- COMPUTATIONS ------------------------------------------------------" << endl;
199 for(int c = 1; c < new_argc; c++) {
201 if(strcmp(new_argv[c], "randomize-train") == 0) {
202 cout << "Generating the training set." << endl;
203 for(int n = 0; n < global.nb_train_samples; n++) {
204 train_labels[n] = int(drand48() * 2);
205 generator->generate(train_labels[n], &train_samples[n]);
209 else if(strcmp(new_argv[c], "adaboost") == 0) {
211 cout << "Building and training adaboost classifier." << endl;
212 classifier = new BoostedClassifier(global.nb_weak_learners);
213 classifier->train(global.nb_train_samples, train_samples, train_labels);
216 else if(strcmp(new_argv[c], "naive-bayesian") == 0) {
218 cout << "Building and training naive bayesian classifier." << endl;
219 classifier = new NaiveBayesianClassifier();
220 classifier->train(global.nb_train_samples, train_samples, train_labels);
223 else if(strcmp(new_argv[c], "read-classifier") == 0) {
225 sprintf(buffer, "%s", global.classifier_name);
226 cout << "Reading classifier from " << buffer << "." << endl;
229 cerr << "Can not open " << buffer << " for reading." << endl;
232 classifier = read_classifier(&in);
235 else if(strcmp(new_argv[c], "write-classifier") == 0) {
236 check(classifier, "No classifier.");
237 sprintf(buffer, "%s/%s", global.result_path, global.classifier_name);
238 cout << "Writing classifier to " << buffer << "." << endl;
239 ofstream out(buffer);
241 cerr << "Can not open " << buffer << " for writing." << endl;
244 classifier->write(&out);
247 else if(strcmp(new_argv[c], "compute-errors-vs-nb-samples") == 0) {
248 for(int t = global.nb_train_samples; t >= 100; t /= 10) {
249 for(int n = 0; n < t; n++) {
250 train_labels[n] = int(drand48() * 2);
251 generator->generate(train_labels[n], &train_samples[n]);
253 Classifier *classifier = 0;
254 cout << "Building and training adaboost classifier with " << t << " samples." << endl;
255 classifier = new BoostedClassifier(global.nb_weak_learners);
256 classifier->train(t, train_samples, train_labels);
257 cout << "ERROR_RATES_VS_NB_SAMPLES "
260 << error_rate(classifier, t, train_samples, train_labels)
262 << test_error_rate(generator, classifier, global.nb_test_samples) << endl;
267 else if(strcmp(new_argv[c], "compute-train-error") == 0) {
268 check(classifier, "No classifier.");
269 cout << "TRAIN_ERROR_RATE "
270 << classifier->name()
272 << error_rate(classifier, global.nb_train_samples, train_samples, train_labels)
276 else if(strcmp(new_argv[c], "compute-test-error") == 0) {
277 check(classifier, "No classifier.");
278 cout << "TEST_ERROR_RATE "
279 << classifier->name()
281 << test_error_rate(generator, classifier, global.nb_test_samples) << endl;
284 else if(strcmp(new_argv[c], "write-samples") == 0) {
286 for(int k = 0; k < global.nb_train_samples; k++) {
287 for(int l = 0; l < 2; l++) {
288 generator->generate(l, &vignette);
289 sprintf(buffer, "%s/sample_%01d_%04d.png", global.result_path, l, k);
290 vignette.write_png(buffer, 1);
291 cout << "Wrote " << buffer << endl;
296 //////////////////////////////////////////////////////////////////////
298 //////////////////////////////////////////////////////////////////////
301 cerr << "Unknown action " << new_argv[c] << endl;
307 cout << "-- FINISHED ----------------------------------------------------------" << endl;
310 delete[] train_labels;
311 delete[] train_samples;