From dfca7e16051d1752db7daed892ecb200237e3bb6 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Wed, 14 Jun 2017 21:00:19 +0200 Subject: [PATCH] Added a bunch of sanity checks. --- svrt.c | 5 +++++ svrt_generator.cc | 13 ++++++++++++- test-svrt.py | 2 ++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/svrt.c b/svrt.c index 1a2449b..307fcf6 100644 --- a/svrt.c +++ b/svrt.c @@ -35,6 +35,11 @@ THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels) { long *m, *l; unsigned char *a, *b; + if(THLongTensor_nDimension(labels) != 1) { + printf("Label tensor has to be of dimension 1.\n"); + exit(1); + } + nb_vignettes = THLongTensor_size(labels, 0); m = THLongTensor_storage(labels)->data + THLongTensor_storageOffset(labels); st0 = THLongTensor_stride(labels, 0); diff --git a/svrt_generator.cc b/svrt_generator.cc index 90f781d..a536005 100644 --- a/svrt_generator.cc +++ b/svrt_generator.cc @@ -157,6 +157,11 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels, VignetteSet *result) { Vignette tmp; + if(n_problem < 1 || n_problem > NB_PROBLEMS) { + printf("Problem number should be between 1 and %d. Provided value is %d.\n", NB_PROBLEMS, n_problem); + exit(1); + } + VignetteGenerator *vg = new_generator(n_problem); result->n_problem = n_problem; result->nb_vignettes = nb_vignettes; @@ -166,7 +171,13 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels, unsigned char *s = result->data; for(int i = 0; i < nb_vignettes; i++) { - vg->generate(labels[i], &tmp); + if(labels[i] == 0 || labels[i] == 1) { + vg->generate(labels[i], &tmp); + } else { + printf("Vignette class label has to be 0 or 1. Provided value is %d.\n", labels[i]); + exit(1); + } + int *r = tmp.content; for(int k = 0; k < Vignette::width * Vignette::height; k++) { *s++ = *r++; diff --git a/test-svrt.py b/test-svrt.py index c1309bc..cd98f21 100755 --- a/test-svrt.py +++ b/test-svrt.py @@ -46,3 +46,5 @@ x = x.view(x.size(0), 1, x.size(1), x.size(2)) x.div_(255) torchvision.utils.save_image(x, 'example.png') + +print('Wrote example.png') -- 2.39.5