#include "svrt_generator.h"
-THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) {
+THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels) {
struct VignetteSet vs;
+ long nb_vignettes;
long st0, st1, st2;
long v, i, j;
+ long *m, *l;
unsigned char *a, *b;
- svrt_generate_vignettes(n_problem, nb_vignettes, &vs);
- printf("SANITY %d %d %d\n", vs.nb_vignettes, vs.width, vs.height);
+ nb_vignettes = THLongTensor_size(labels, 0);
+ m = THLongTensor_storage(labels)->data + THLongTensor_storageOffset(labels);
+ st0 = THLongTensor_stride(labels, 0);
+ l = (long *) malloc(sizeof(long) * nb_vignettes);
+ for(v = 0; v < nb_vignettes; v++) {
+ l[v] = *m;
+ m += st0;
+ }
+
+ svrt_generate_vignettes(n_problem, nb_vignettes, l, &vs);
+ free(l);
THLongStorage *size = THLongStorage_newWithSize(3);
size->data[0] = vs.nb_vignettes;
}
}
+ free(vs.data);
+
return result;
}
*
*/
-THByteTensor *generate_vignettes(long n_problem, long nb_images);
+THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels);
unsigned char *data;
};
-void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
+void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
+ VignetteSet *result) {
Vignette tmp;
VignetteGenerator *vg = new_generator(n_problem);
unsigned char *s = result->data;
for(int i = 0; i < nb_vignettes; i++) {
- vg->generate(drand48() < 0.5 ? 1 : 0, &tmp);
+ vg->generate(labels[i], &tmp);
int *r = tmp.content;
for(int k = 0; k < Vignette::width * Vignette::height; k++) {
*s++ = *r++;
unsigned char *data;
};
-void svrt_generate_vignettes(int n_problem, int nb_vignettes, struct VignetteSet *result);
+ void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
+ struct VignetteSet *result);
#ifdef __cplusplus
}
from _ext import svrt
-train_set = svrt.generate_vignettes(12, 64)
+labels = torch.LongTensor(12).zero_()
+labels.narrow(0, 0, labels.size(0)//2).fill_(1)
+
+train_set = svrt.generate_vignettes(4, labels)
print(str(type(train_set)), train_set.size())
+train_set = train_set.view(train_set.size(0), 1, train_set.size(1), train_set.size(2))
+
train_set.div_(255)
-torchvision.utils.save_image(train_set.view(train_set.size(0), 1, train_set.size(1), train_set.size(2)), 'example.png')
+torchvision.utils.save_image(train_set, 'example.png')