Added a bunch of sanity checks.
authorFrancois Fleuret <francois@fleuret.org>
Wed, 14 Jun 2017 19:00:19 +0000 (21:00 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 14 Jun 2017 19:00:19 +0000 (21:00 +0200)
svrt.c
svrt_generator.cc
test-svrt.py

diff --git a/svrt.c b/svrt.c
index 1a2449b..307fcf6 100644 (file)
--- a/svrt.c
+++ b/svrt.c
@@ -35,6 +35,11 @@ THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels) {
   long *m, *l;
   unsigned char *a, *b;
 
+  if(THLongTensor_nDimension(labels) != 1) {
+    printf("Label tensor has to be of dimension 1.\n");
+    exit(1);
+  }
+
   nb_vignettes = THLongTensor_size(labels, 0);
   m = THLongTensor_storage(labels)->data + THLongTensor_storageOffset(labels);
   st0 = THLongTensor_stride(labels, 0);
index 90f781d..a536005 100644 (file)
@@ -157,6 +157,11 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
                              VignetteSet *result) {
   Vignette tmp;
 
+  if(n_problem < 1 || n_problem > NB_PROBLEMS) {
+    printf("Problem number should be between 1 and %d. Provided value is %d.\n", NB_PROBLEMS, n_problem);
+    exit(1);
+  }
+
   VignetteGenerator *vg = new_generator(n_problem);
   result->n_problem = n_problem;
   result->nb_vignettes = nb_vignettes;
@@ -166,7 +171,13 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
 
   unsigned char *s = result->data;
   for(int i = 0; i < nb_vignettes; i++) {
-    vg->generate(labels[i], &tmp);
+    if(labels[i] == 0 || labels[i] == 1) {
+      vg->generate(labels[i], &tmp);
+    } else {
+      printf("Vignette class label has to be 0 or 1. Provided value is %d.\n", labels[i]);
+      exit(1);
+    }
+
     int *r = tmp.content;
     for(int k = 0; k < Vignette::width * Vignette::height; k++) {
       *s++ = *r++;
index c1309bc..cd98f21 100755 (executable)
@@ -46,3 +46,5 @@ x = x.view(x.size(0), 1, x.size(1), x.size(2))
 x.div_(255)
 
 torchvision.utils.save_image(x, 'example.png')
+
+print('Wrote example.png')