svrt.generate_vignettes now takes a 1d label tensor as arguments.
authorFrancois Fleuret <francois@fleuret.org>
Wed, 14 Jun 2017 16:27:51 +0000 (18:27 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 14 Jun 2017 16:27:51 +0000 (18:27 +0200)
svrt.c
svrt.h
svrt_generator.cc
svrt_generator.h
test-svrt.py

diff --git a/svrt.c b/svrt.c
index fdee66f..1a2449b 100644 (file)
--- a/svrt.c
+++ b/svrt.c
 
 #include "svrt_generator.h"
 
-THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) {
+THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels) {
   struct VignetteSet vs;
+  long nb_vignettes;
   long st0, st1, st2;
   long v, i, j;
+  long *m, *l;
   unsigned char *a, *b;
 
-  svrt_generate_vignettes(n_problem, nb_vignettes, &vs);
-  printf("SANITY %d %d %d\n", vs.nb_vignettes, vs.width, vs.height);
+  nb_vignettes = THLongTensor_size(labels, 0);
+  m = THLongTensor_storage(labels)->data + THLongTensor_storageOffset(labels);
+  st0 = THLongTensor_stride(labels, 0);
+  l = (long *) malloc(sizeof(long) * nb_vignettes);
+  for(v = 0; v < nb_vignettes; v++) {
+    l[v] = *m;
+    m += st0;
+  }
+
+  svrt_generate_vignettes(n_problem, nb_vignettes, l, &vs);
+  free(l);
 
   THLongStorage *size = THLongStorage_newWithSize(3);
   size->data[0] = vs.nb_vignettes;
@@ -61,5 +72,7 @@ THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) {
     }
   }
 
+  free(vs.data);
+
   return result;
 }
diff --git a/svrt.h b/svrt.h
index 4335df3..77b8b46 100644 (file)
--- a/svrt.h
+++ b/svrt.h
@@ -23,4 +23,4 @@
  *
  */
 
-THByteTensor *generate_vignettes(long n_problem, long nb_images);
+THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels);
index 80cfd12..90f781d 100644 (file)
@@ -153,7 +153,8 @@ struct VignetteSet {
   unsigned char *data;
 };
 
-void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
+void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
+                             VignetteSet *result) {
   Vignette tmp;
 
   VignetteGenerator *vg = new_generator(n_problem);
@@ -165,7 +166,7 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *resul
 
   unsigned char *s = result->data;
   for(int i = 0; i < nb_vignettes; i++) {
-    vg->generate(drand48() < 0.5 ? 1 : 0, &tmp);
+    vg->generate(labels[i], &tmp);
     int *r = tmp.content;
     for(int k = 0; k < Vignette::width * Vignette::height; k++) {
       *s++ = *r++;
index bdfe5c1..7f6a3ad 100644 (file)
@@ -35,7 +35,8 @@ struct VignetteSet {
   unsigned char *data;
 };
 
-void svrt_generate_vignettes(int n_problem, int nb_vignettes, struct VignetteSet *result);
+  void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels,
+                               struct VignetteSet *result);
 
 #ifdef __cplusplus
 }
index 6b5f826..9aa2d59 100755 (executable)
@@ -36,10 +36,15 @@ from torchvision import datasets, transforms, utils
 
 from _ext import svrt
 
-train_set = svrt.generate_vignettes(12, 64)
+labels = torch.LongTensor(12).zero_()
+labels.narrow(0, 0, labels.size(0)//2).fill_(1)
+
+train_set = svrt.generate_vignettes(4, labels)
 
 print(str(type(train_set)), train_set.size())
 
+train_set = train_set.view(train_set.size(0), 1, train_set.size(1), train_set.size(2))
+
 train_set.div_(255)
 
-torchvision.utils.save_image(train_set.view(train_set.size(0), 1, train_set.size(1), train_set.size(2)), 'example.png')
+torchvision.utils.save_image(train_set, 'example.png')