Test now saves an example image.
authorFrancois Fleuret <francois@fleuret.org>
Wed, 14 Jun 2017 16:06:35 +0000 (18:06 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 14 Jun 2017 16:06:35 +0000 (18:06 +0200)
svrt.c
svrt_generator.cc
test-svrt.py

diff --git a/svrt.c b/svrt.c
index 0f53642..fdee66f 100644 (file)
--- a/svrt.c
+++ b/svrt.c
 
 THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) {
   struct VignetteSet vs;
+  long st0, st1, st2;
+  long v, i, j;
+  unsigned char *a, *b;
 
   svrt_generate_vignettes(n_problem, nb_vignettes, &vs);
   printf("SANITY %d %d %d\n", vs.nb_vignettes, vs.width, vs.height);
 
   THLongStorage *size = THLongStorage_newWithSize(3);
-  size->data[0] = nb_vignettes;
+  size->data[0] = vs.nb_vignettes;
   size->data[1] = vs.height;
   size->data[2] = vs.width;
 
   THByteTensor *result = THByteTensor_newWithSize(size, NULL);
   THLongStorage_free(size);
 
-  /* st0 = THByteTensor_stride(result, 0); */
-  /* st1 = THByteTensor_stride(result, 1); */
-  /* st2 = THByteTensor_stride(result, 2); */
+  st0 = THByteTensor_stride(result, 0);
+  st1 = THByteTensor_stride(result, 1);
+  st2 = THByteTensor_stride(result, 2);
+
+  unsigned char *r = vs.data;
+  for(v = 0; v < vs.nb_vignettes; v++) {
+    a = THByteTensor_storage(result)->data + THByteTensor_storageOffset(result) + v * st0;
+    for(i = 0; i < vs.height; i++) {
+      b = a + i * st1;
+      for(j = 0; j < vs.width; j++) {
+        *b = (unsigned char) (*r);
+        r++;
+        b += st2;
+      }
+    }
+  }
 
   return result;
 }
index 82b7c3b..80cfd12 100644 (file)
@@ -145,22 +145,34 @@ VignetteGenerator *new_generator(int nb) {
 
 extern "C" {
 
-  struct VignetteSet {
-    int n_problem;
-    int nb_vignettes;
-    int width;
-    int height;
-    unsigned char *data;
-  };
-
-  void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
-    VignetteGenerator *vg = new_generator(n_problem);
-    result->n_problem = n_problem;
-    result->nb_vignettes = nb_vignettes;
-    result->width = Vignette::width;
-    result->height = Vignette::height;
-    result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height);
-    delete vg;
+struct VignetteSet {
+  int n_problem;
+  int nb_vignettes;
+  int width;
+  int height;
+  unsigned char *data;
+};
+
+void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) {
+  Vignette tmp;
+
+  VignetteGenerator *vg = new_generator(n_problem);
+  result->n_problem = n_problem;
+  result->nb_vignettes = nb_vignettes;
+  result->width = Vignette::width;
+  result->height = Vignette::height;
+  result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height);
+
+  unsigned char *s = result->data;
+  for(int i = 0; i < nb_vignettes; i++) {
+    vg->generate(drand48() < 0.5 ? 1 : 0, &tmp);
+    int *r = tmp.content;
+    for(int k = 0; k < Vignette::width * Vignette::height; k++) {
+      *s++ = *r++;
+    }
   }
 
+  delete vg;
+}
+
 }
index 92fc554..6b5f826 100755 (executable)
 import time
 
 import torch
+import torchvision
 
 from torch import optim
 from torch import FloatTensor as Tensor
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as fn
+
 from torchvision import datasets, transforms, utils
 
 from _ext import svrt
 
-train_set = svrt.generate_vignettes(12, 1234)
+train_set = svrt.generate_vignettes(12, 64)
 
 print(str(type(train_set)), train_set.size())
+
+train_set.div_(255)
+
+torchvision.utils.save_image(train_set.view(train_set.size(0), 1, train_set.size(1), train_set.size(2)), 'example.png')