projects
/
pysvrt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
OCD cosmetics.
[pysvrt.git]
/
svrt_generator.cc
diff --git
a/svrt_generator.cc
b/svrt_generator.cc
index
80cfd12
..
1fa4f40
100644
(file)
--- a/
svrt_generator.cc
+++ b/
svrt_generator.cc
@@
-153,9
+153,15
@@
struct VignetteSet {
unsigned char *data;
};
unsigned char *data;
};
-void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
+void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
+ VignetteSet *result) {
Vignette tmp;
Vignette tmp;
+ if(n_problem < 1 || n_problem > NB_PROBLEMS) {
+ printf("Problem number should be between 1 and %d. Provided value is %d.\n", NB_PROBLEMS, n_problem);
+ exit(1);
+ }
+
VignetteGenerator *vg = new_generator(n_problem);
result->n_problem = n_problem;
result->nb_vignettes = nb_vignettes;
VignetteGenerator *vg = new_generator(n_problem);
result->n_problem = n_problem;
result->nb_vignettes = nb_vignettes;
@@
-165,7
+171,13
@@
void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *resul
unsigned char *s = result->data;
for(int i = 0; i < nb_vignettes; i++) {
unsigned char *s = result->data;
for(int i = 0; i < nb_vignettes; i++) {
- vg->generate(drand48() < 0.5 ? 1 : 0, &tmp);
+ if(labels[i] == 0 || labels[i] == 1) {
+ vg->generate(labels[i], &tmp);
+ } else {
+ printf("Vignette class label has to be 0 or 1. Provided value is %ld.\n", labels[i]);
+ exit(1);
+ }
+
int *r = tmp.content;
for(int k = 0; k < Vignette::width * Vignette::height; k++) {
*s++ = *r++;
int *r = tmp.content;
for(int k = 0; k < Vignette::width * Vignette::height; k++) {
*s++ = *r++;