X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=svrt_generator.cc;h=5d64806aeb75300f9c2c7a51d01277bb83192293;hb=34aeb8100a6c19dae72779f9e46a0acbb5a069c7;hp=90f781d39910722ffd681a7965536512e029b5e0;hpb=24368498f3065e8a4be34c5e8e2b68f9d1220f7d;p=pysvrt.git
diff --git a/svrt_generator.cc b/svrt_generator.cc
index 90f781d..5d64806 100644
--- a/svrt_generator.cc
+++ b/svrt_generator.cc
@@ -18,7 +18,7 @@
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
- * along with selector. If not, see .
+ * along with pysvrt. If not, see .
*
*/
@@ -157,6 +157,11 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
VignetteSet *result) {
Vignette tmp;
+ if(n_problem < 1 || n_problem > NB_PROBLEMS) {
+ printf("Problem number should be between 1 and %d. Provided value is %d.\n", NB_PROBLEMS, n_problem);
+ exit(1);
+ }
+
VignetteGenerator *vg = new_generator(n_problem);
result->n_problem = n_problem;
result->nb_vignettes = nb_vignettes;
@@ -166,7 +171,13 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
unsigned char *s = result->data;
for(int i = 0; i < nb_vignettes; i++) {
- vg->generate(labels[i], &tmp);
+ if(labels[i] == 0 || labels[i] == 1) {
+ vg->generate(labels[i], &tmp);
+ } else {
+ printf("Vignette class label has to be 0 or 1. Provided value is %ld.\n", labels[i]);
+ exit(1);
+ }
+
int *r = tmp.content;
for(int k = 0; k < Vignette::width * Vignette::height; k++) {
*s++ = *r++;