X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=svrt_generator.cc;h=33b98ee905541633632374f08ec3334806f9f3d6;hb=b72b4f27eea7b800e3f1e93af4ce0ae8065d0bc1;hp=82b7c3b154c009e845b97c81ba4fccd5a3195527;hpb=f542d0542b1e51ca7dd12bc6b96f6a299371ae8d;p=pysvrt.git
diff --git a/svrt_generator.cc b/svrt_generator.cc
index 82b7c3b..33b98ee 100644
--- a/svrt_generator.cc
+++ b/svrt_generator.cc
@@ -18,7 +18,7 @@
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
- * along with selector. If not, see .
+ * along with svrt. If not, see .
*
*/
@@ -145,22 +145,46 @@ VignetteGenerator *new_generator(int nb) {
extern "C" {
- struct VignetteSet {
- int n_problem;
- int nb_vignettes;
- int width;
- int height;
- unsigned char *data;
- };
-
- void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
- VignetteGenerator *vg = new_generator(n_problem);
- result->n_problem = n_problem;
- result->nb_vignettes = nb_vignettes;
- result->width = Vignette::width;
- result->height = Vignette::height;
- result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height);
- delete vg;
+struct VignetteSet {
+ int n_problem;
+ int nb_vignettes;
+ int width;
+ int height;
+ unsigned char *data;
+};
+
+void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
+ VignetteSet *result) {
+ Vignette tmp;
+
+ if(n_problem < 1 || n_problem > NB_PROBLEMS) {
+ printf("Problem number should be between 1 and %d. Provided value is %d.\n", NB_PROBLEMS, n_problem);
+ exit(1);
+ }
+
+ VignetteGenerator *vg = new_generator(n_problem);
+ result->n_problem = n_problem;
+ result->nb_vignettes = nb_vignettes;
+ result->width = Vignette::width;
+ result->height = Vignette::height;
+ result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height);
+
+ unsigned char *s = result->data;
+ for(int i = 0; i < nb_vignettes; i++) {
+ if(labels[i] == 0 || labels[i] == 1) {
+ vg->generate(labels[i], &tmp);
+ } else {
+ printf("Vignette class label has to be 0 or 1. Provided value is %ld.\n", labels[i]);
+ exit(1);
+ }
+
+ int *r = tmp.content;
+ for(int k = 0; k < Vignette::width * Vignette::height; k++) {
+ *s++ = *r++;
+ }
}
+ delete vg;
+}
+
}