Started to implement the gravity by making the pulling optional.
[flatland.git] / flatland.c
1 /*
2
3   Example of FFI extension I started from:
4
5     https://github.com/pytorch/extension-ffi.git
6
7   There is this tutorial
8
9     https://github.com/pytorch/tutorials/blob/master/Creating%20Extensions%20using%20FFI.md
10
11   And TH's Tensor definitions are here in my install:
12
13     anaconda3/lib/python3.5/site-packages/torch/lib/include/TH/generic/THTensor.h
14
15  */
16
17 #include <TH/TH.h>
18
19 #include "sequence_generator.h"
20
21 THByteTensor *generate_sequence(int pulling,
22                                 long nb_sequences,
23                                 long nb_images,
24                                 long image_height, long image_width,
25                                 long nb_shapes,
26                                 int random_shape_size, int random_colors) {
27
28   long nb_channels = 3;
29   unsigned char *a, *b;
30   long s, c, k, i, j, st0, st1, st2, st3, st4;
31
32   THLongStorage *size = THLongStorage_newWithSize(5);
33   size->data[0] = nb_sequences;
34   size->data[1] = nb_images;
35   size->data[2] = nb_channels;
36   size->data[3] = image_height;
37   size->data[4] = image_width;
38   THByteTensor *result = THByteTensor_newWithSize(size, NULL);
39   THLongStorage_free(size);
40
41   st0 = THByteTensor_stride(result, 0);
42   st1 = THByteTensor_stride(result, 1);
43   st2 = THByteTensor_stride(result, 2);
44   st3 = THByteTensor_stride(result, 3);
45   st4 = THByteTensor_stride(result, 4);
46
47   unsigned char tmp_buffer[nb_images * nb_channels * image_width * image_height];
48
49   for(s = 0; s < nb_sequences; s++) {
50     a = THByteTensor_storage(result)->data + THByteTensor_storageOffset(result) + s * st0;
51     fl_generate_sequence(nb_images, image_width, image_height, nb_shapes,
52                          random_shape_size, random_colors,
53                          pulling,
54                          tmp_buffer);
55     unsigned char *r = tmp_buffer;
56     for(k = 0; k < nb_images; k++) {
57       for(c = 0; c < nb_channels; c++) {
58         for(i = 0; i < image_height; i++) {
59           b = a + k * st1 + c * st2 + i * st3;
60           for(j = 0; j < image_width; j++) {
61             *b = (unsigned char) (*r);
62             r++;
63             b += st4;
64           }
65         }
66       }
67     }
68   }
69
70   return result;
71 }