Minor update.
[pysvrt.git] / svrt_generator.cc
1 /*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include <iostream>
26 #include <fstream>
27 #include <cmath>
28 #include <stdio.h>
29 #include <stdlib.h>
30
31 using namespace std;
32
33 #include "random.h"
34
35 #include "vision_problem_1.h"
36 #include "vision_problem_2.h"
37 #include "vision_problem_3.h"
38 #include "vision_problem_4.h"
39 #include "vision_problem_5.h"
40 #include "vision_problem_6.h"
41 #include "vision_problem_7.h"
42 #include "vision_problem_8.h"
43 #include "vision_problem_9.h"
44 #include "vision_problem_10.h"
45 #include "vision_problem_11.h"
46 #include "vision_problem_12.h"
47 #include "vision_problem_13.h"
48 #include "vision_problem_14.h"
49 #include "vision_problem_15.h"
50 #include "vision_problem_16.h"
51 #include "vision_problem_17.h"
52 #include "vision_problem_18.h"
53 #include "vision_problem_19.h"
54 #include "vision_problem_20.h"
55 #include "vision_problem_21.h"
56 #include "vision_problem_22.h"
57 #include "vision_problem_23.h"
58
59 #define NB_PROBLEMS 23
60
61 VignetteGenerator *new_generator(int nb) {
62   VignetteGenerator *generator;
63
64   switch(nb) {
65   case 1:
66     generator = new VisionProblem_1();
67     break;
68   case 2:
69     generator = new VisionProblem_2();
70     break;
71   case 3:
72     generator = new VisionProblem_3();
73     break;
74   case 4:
75     generator = new VisionProblem_4();
76     break;
77   case 5:
78     generator = new VisionProblem_5();
79     break;
80   case 6:
81     generator = new VisionProblem_6();
82     break;
83   case 7:
84     generator = new VisionProblem_7();
85     break;
86   case 8:
87     generator = new VisionProblem_8();
88     break;
89   case 9:
90     generator = new VisionProblem_9();
91     break;
92   case 10:
93     generator = new VisionProblem_10();
94     break;
95   case 11:
96     generator = new VisionProblem_11();
97     break;
98   case 12:
99     generator = new VisionProblem_12();
100     break;
101   case 13:
102     generator = new VisionProblem_13();
103     break;
104   case 14:
105     generator = new VisionProblem_14();
106     break;
107   case 15:
108     generator = new VisionProblem_15();
109     break;
110   case 16:
111     generator = new VisionProblem_16();
112     break;
113   case 17:
114     generator = new VisionProblem_17();
115     break;
116   case 18:
117     generator = new VisionProblem_18();
118     break;
119   case 19:
120     generator = new VisionProblem_19();
121     break;
122   case 20:
123     generator = new VisionProblem_20();
124     break;
125   case 21:
126     generator = new VisionProblem_21();
127     break;
128   case 22:
129     generator = new VisionProblem_22();
130     break;
131   case 23:
132     generator = new VisionProblem_23();
133     break;
134   default:
135     cerr << "Can not find problem "
136          << nb
137          << endl;
138     abort();
139   }
140
141   generator->precompute();
142
143   return generator;
144 }
145
146 extern "C" {
147
148 struct VignetteSet {
149   int n_problem;
150   int nb_vignettes;
151   int width;
152   int height;
153   unsigned char *data;
154 };
155
156 void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
157                              VignetteSet *result) {
158   Vignette tmp;
159
160   if(n_problem < 1 || n_problem > NB_PROBLEMS) {
161     printf("Problem number should be between 1 and %d. Provided value is %d.\n", NB_PROBLEMS, n_problem);
162     exit(1);
163   }
164
165   VignetteGenerator *vg = new_generator(n_problem);
166   result->n_problem = n_problem;
167   result->nb_vignettes = nb_vignettes;
168   result->width = Vignette::width;
169   result->height = Vignette::height;
170   result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height);
171
172   unsigned char *s = result->data;
173   for(int i = 0; i < nb_vignettes; i++) {
174     if(labels[i] == 0 || labels[i] == 1) {
175       vg->generate(labels[i], &tmp);
176     } else {
177       printf("Vignette class label has to be 0 or 1. Provided value is %ld.\n", labels[i]);
178       exit(1);
179     }
180
181     int *r = tmp.content;
182     for(int k = 0; k < Vignette::width * Vignette::height; k++) {
183       *s++ = *r++;
184     }
185   }
186
187   delete vg;
188 }
189
190 }