Minor update.
[pysvrt.git] / svrt.c
1
2 /*
3  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
4  *  generator for evaluating classification performance of machine
5  *  learning systems, humans and primates.
6  *
7  *  Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
8  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
9  *
10  *  This file is part of svrt.
11  *
12  *  svrt is free software: you can redistribute it and/or modify it
13  *  under the terms of the GNU General Public License version 3 as
14  *  published by the Free Software Foundation.
15  *
16  *  svrt is distributed in the hope that it will be useful, but
17  *  WITHOUT ANY WARRANTY; without even the implied warranty of
18  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  *  General Public License for more details.
20  *
21  *  You should have received a copy of the GNU General Public License
22  *  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
23  *
24  */
25
26 #include <TH/TH.h>
27
28 #include "svrt_generator.h"
29
30 THByteStorage *compress(THByteStorage *x) {
31   long k, g, n;
32
33   k = 0; n = 0;
34   while(k < x->size) {
35     g = 0;
36     while(k < x->size && x->data[k] == 255 && g < 255) { g++; k++; }
37     n++;
38     if(k < x->size && g < 255) { k++; }
39   }
40
41   if(x->data[k-1] == 0) {
42     n++;
43   }
44
45   THByteStorage *result = THByteStorage_newWithSize(n);
46
47   k = 0; n = 0;
48   while(k < x->size) {
49     g = 0;
50     while(k < x->size && x->data[k] == 255 && g < 255) { g++; k++; }
51     result->data[n++] = g;
52     if(k < x->size && g < 255) { k++; }
53   }
54   if(x->data[k-1] == 0) {
55     result->data[n++] = 0;
56   }
57
58   return result;
59 }
60
61 THByteStorage *uncompress(THByteStorage *x) {
62   long k, g, n;
63
64   k = 0;
65   for(n = 0; n < x->size - 1; n++) {
66     k = k + x->data[n];
67     if(x->data[n] < 255) { k++; }
68   }
69   k = k + x->data[n];
70
71   THByteStorage *result = THByteStorage_newWithSize(k);
72
73   k = 0;
74   for(n = 0; n < x->size - 1; n++) {
75     for(g = 0; g < x->data[n]; g++) {
76       result->data[k++] = 255;
77     }
78     if(x->data[n] < 255) {
79       result->data[k++] = 0;
80     }
81   }
82   for(g = 0; g < x->data[n]; g++) {
83     result->data[k++] = 255;
84   }
85
86   return result;
87 }
88
89 void seed(long s) {
90   srand48(s);
91 }
92
93 THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels) {
94   struct VignetteSet vs;
95   long nb_vignettes;
96   long st0, st1, st2;
97   long v, i, j;
98   long *m, *l;
99   unsigned char *a, *b;
100
101   if(THLongTensor_nDimension(labels) != 1) {
102     printf("Label tensor has to be of dimension 1.\n");
103     exit(1);
104   }
105
106   nb_vignettes = THLongTensor_size(labels, 0);
107   m = THLongTensor_storage(labels)->data + THLongTensor_storageOffset(labels);
108   st0 = THLongTensor_stride(labels, 0);
109   l = (long *) malloc(sizeof(long) * nb_vignettes);
110   for(v = 0; v < nb_vignettes; v++) {
111     l[v] = *m;
112     m += st0;
113   }
114
115   svrt_generate_vignettes(n_problem, nb_vignettes, l, &vs);
116   free(l);
117
118   THLongStorage *size = THLongStorage_newWithSize(3);
119   size->data[0] = vs.nb_vignettes;
120   size->data[1] = vs.height;
121   size->data[2] = vs.width;
122
123   THByteTensor *result = THByteTensor_newWithSize(size, NULL);
124   THLongStorage_free(size);
125
126   st0 = THByteTensor_stride(result, 0);
127   st1 = THByteTensor_stride(result, 1);
128   st2 = THByteTensor_stride(result, 2);
129
130   unsigned char *r = vs.data;
131   for(v = 0; v < vs.nb_vignettes; v++) {
132     a = THByteTensor_storage(result)->data + THByteTensor_storageOffset(result) + v * st0;
133     for(i = 0; i < vs.height; i++) {
134       b = a + i * st1;
135       for(j = 0; j < vs.width; j++) {
136         *b = (unsigned char) (*r);
137         r++;
138         b += st2;
139       }
140     }
141   }
142
143   free(vs.data);
144
145   return result;
146 }