projects
/
pysvrt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
svrt.generate_vignettes now takes a 1d label tensor as arguments.
[pysvrt.git]
/
svrt_generator.cc
diff --git
a/svrt_generator.cc
b/svrt_generator.cc
index
80cfd12
..
90f781d
100644
(file)
--- a/
svrt_generator.cc
+++ b/
svrt_generator.cc
@@
-153,7
+153,8
@@
struct VignetteSet {
unsigned char *data;
};
unsigned char *data;
};
-void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
+void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
+ VignetteSet *result) {
Vignette tmp;
VignetteGenerator *vg = new_generator(n_problem);
Vignette tmp;
VignetteGenerator *vg = new_generator(n_problem);
@@
-165,7
+166,7
@@
void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *resul
unsigned char *s = result->data;
for(int i = 0; i < nb_vignettes; i++) {
unsigned char *s = result->data;
for(int i = 0; i < nb_vignettes; i++) {
- vg->generate(
drand48() < 0.5 ? 1 : 0
, &tmp);
+ vg->generate(
labels[i]
, &tmp);
int *r = tmp.content;
for(int k = 0; k < Vignette::width * Vignette::height; k++) {
*s++ = *r++;
int *r = tmp.content;
for(int k = 0; k < Vignette::width * Vignette::height; k++) {
*s++ = *r++;