Minor fixes + changed the default number of epochs to 100.
[pysvrt.git] / svrt_generator.cc
index 90f781d..1fa4f40 100644 (file)
@@ -157,6 +157,11 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
                              VignetteSet *result) {
   Vignette tmp;
 
+  if(n_problem < 1 || n_problem > NB_PROBLEMS) {
+    printf("Problem number should be between 1 and %d. Provided value is %d.\n", NB_PROBLEMS, n_problem);
+    exit(1);
+  }
+
   VignetteGenerator *vg = new_generator(n_problem);
   result->n_problem = n_problem;
   result->nb_vignettes = nb_vignettes;
@@ -166,7 +171,13 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
 
   unsigned char *s = result->data;
   for(int i = 0; i < nb_vignettes; i++) {
-    vg->generate(labels[i], &tmp);
+    if(labels[i] == 0 || labels[i] == 1) {
+      vg->generate(labels[i], &tmp);
+    } else {
+      printf("Vignette class label has to be 0 or 1. Provided value is %ld.\n", labels[i]);
+      exit(1);
+    }
+
     int *r = tmp.content;
     for(int k = 0; k < Vignette::width * Vignette::height; k++) {
       *s++ = *r++;